From d750a5773334c0689e62c2743b17adbdf57e1df3 Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Tue, 27 Jan 2026 11:53:42 +0100
Subject: [PATCH 01/14] Agent status

---
 pkg/debug/.dockerignore           |   10 +
 pkg/debug/build-legacy.Dockerfile |   69 ++
 pkg/debug/build.Dockerfile        |   59 ++
 pkg/debug/diagnostic.sh           | 1038 +++++++++++++++++++++++++++++
 pkg/debug/local_build.sh          |   22 +
 src/agent_status.rs               |   78 ++
 src/bpf_stats_noop.rs             |  236 +++++++
 src/bpf_stub.rs                   |   29 +
 src/firewall_noop.rs              |  290 ++++++++
 src/main.rs                       |  124 ++++
 src/utils/bpf_utils_noop.rs       |   16 +
 src/utils/tcp_fingerprint_noop.rs |  166 +++++
 src/worker/agent_status.rs        |   57 ++
 src/worker/log.rs                 |    5 +
 src/worker/mod.rs                 |    1 +
 15 files changed, 2200 insertions(+)
 create mode 100644 pkg/debug/.dockerignore
 create mode 100644 pkg/debug/build-legacy.Dockerfile
 create mode 100644 pkg/debug/build.Dockerfile
 create mode 100644 pkg/debug/diagnostic.sh
 create mode 100644 pkg/debug/local_build.sh
 create mode 100644 src/agent_status.rs
 create mode 100644 src/bpf_stats_noop.rs
 create mode 100644 src/bpf_stub.rs
 create mode 100644 src/firewall_noop.rs
 create mode 100644 src/utils/bpf_utils_noop.rs
 create mode 100644 src/utils/tcp_fingerprint_noop.rs
 create mode 100644 src/worker/agent_status.rs

diff --git a/pkg/debug/.dockerignore b/pkg/debug/.dockerignore
new file mode 100644
index 0000000..dd3f5fa
--- /dev/null
+++ b/pkg/debug/.dockerignore
@@ -0,0 +1,10 @@
+target
+.git
+.vscode
+helm
+docs
+.github
+.devcontainer
+images
+*.md
+docker/volumes
diff --git a/pkg/debug/build-legacy.Dockerfile b/pkg/debug/build-legacy.Dockerfile
new file mode 100644
index 0000000..ce004f4
--- /dev/null
+++ b/pkg/debug/build-legacy.Dockerfile
@@ -0,0 +1,69 @@
+ARG IMAGE="ubuntu"
+ARG IMAGE_TAG="16.04"
+
+FROM ${IMAGE}:${IMAGE_TAG}
+
+# Re-declared after FROM: an ARG defined before FROM goes out of scope once the
+# build stage starts, so --build-arg BUILD_FLAGS would otherwise be ignored.
+ARG BUILD_FLAGS
+
+RUN sed -i '/updates/d' /etc/apt/sources.list && \
+    sed -i 's/httpredir/archive/' /etc/apt/sources.list && \
+    sed -i 's|https://|http://|g' /etc/apt/sources.list
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    apt-transport-https \
+    ca-certificates \
+    curl \
+    gnupg \
+    lsb-release && \
+    curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \
+    echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-10 main" >> /etc/apt/sources.list.d/llvm.list && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends \
+    libc6 \
+    libc6-dev \
+    g++ \
+    gcc \
+    make && \
+    apt-get install -y --no-install-recommends \
+    git \
+    build-essential \
+    clang-10 \
+    llvm-10 \
+    libelf-dev \
+    libelf1 \
+    libssl-dev \
+    zlib1g-dev \
+    libzstd-dev \
+    pkg-config \
+    libcap-dev \
+    binutils-multiarch-dev \
+    cmake && \
+    update-alternatives --install /usr/bin/clang clang /usr/bin/clang-10 100 && \
+    update-alternatives --install /usr/bin/llc llc /usr/bin/llc-10 100 && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
+    . "$HOME/.cargo/env" && \
+    rustup default stable && \
+    rustup update stable
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+WORKDIR /app
+
+COPY . .
+
+# If BUILD_FLAGS is unset or empty, default to --no-default-features (no eBPF).
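+# Usage sketch (flag values illustrative, not prescribed by this patch):
+#   docker build -f pkg/debug/build-legacy.Dockerfile \
+#     --build-arg BUILD_FLAGS="--no-default-features" .
+# The ${VAR:-default} expansion below substitutes its default only when
+# BUILD_FLAGS is unset or empty, so an explicit --build-arg always takes effect.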
+RUN cargo build --release ${BUILD_FLAGS:---no-default-features}
+
+# Create output directory and copy binary
+RUN mkdir -p /output && \
+    cp target/release/synapse /output/synapse
+
+VOLUME ["/output"]
diff --git a/pkg/debug/build.Dockerfile b/pkg/debug/build.Dockerfile
new file mode 100644
index 0000000..23d8249
--- /dev/null
+++ b/pkg/debug/build.Dockerfile
@@ -0,0 +1,59 @@
+ARG IMAGE="ubuntu"
+ARG IMAGE_TAG="18.04"
+
+FROM ${IMAGE}:${IMAGE_TAG}
+ARG BUILD_FLAGS=""
+
+RUN sed -i '/updates/d' /etc/apt/sources.list && \
+    sed -i 's/httpredir/archive/' /etc/apt/sources.list
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    ca-certificates \
+    curl \
+    gnupg \
+    lsb-release && \
+    curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \
+    echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends --allow-downgrades \
+    libc6=2.27-3ubuntu1.5 \
+    libc6-dev \
+    g++ \
+    gcc \
+    make && \
+    apt-get install -y --no-install-recommends \
+    git \
+    build-essential \
+    clang-10 \
+    llvm-10 \
+    libelf-dev \
+    libelf1 \
+    libssl-dev \
+    zlib1g-dev \
+    libzstd-dev \
+    pkg-config \
+    libcap-dev \
+    binutils-multiarch-dev \
+    cmake && \
+    update-alternatives --install /usr/bin/clang clang /usr/bin/clang-10 100 && \
+    update-alternatives --install /usr/bin/llc llc /usr/bin/llc-10 100 && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
+    . "$HOME/.cargo/env" && \
+    rustup default stable && \
+    rustup update stable
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+WORKDIR /app
+
+COPY . .
+
+RUN cargo build --release $BUILD_FLAGS
+
+# Create output directory and copy binary
+RUN mkdir -p /output && \
+    cp target/release/synapse /output/synapse
+
+VOLUME ["/output"]
diff --git a/pkg/debug/diagnostic.sh b/pkg/debug/diagnostic.sh
new file mode 100644
index 0000000..1d3b7ff
--- /dev/null
+++ b/pkg/debug/diagnostic.sh
@@ -0,0 +1,1038 @@
+#!/bin/bash
+#
+# Synapse Diagnostic Script
+# Checks system capabilities for running Synapse with all its features
+#
+
+set -e
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+CYAN='\033[0;36m'
+NC='\033[0m' # No Color
+BOLD='\033[1m'
+
+# Feature tracking
+declare -A FEATURES
+FEATURES["xdp"]=0
+FEATURES["xdp_hardware"]=0
+FEATURES["xdp_driver"]=0
+FEATURES["xdp_skb"]=0
+FEATURES["nftables"]=0
+FEATURES["iptables"]=0
+FEATURES["ip6tables"]=0
+FEATURES["bpf_syscall"]=0
+FEATURES["bpf_jit"]=0
+FEATURES["btf"]=0
+FEATURES["ipv6"]=0
+
+# Minimum requirements
+MIN_KERNEL_MAJOR=4
+MIN_KERNEL_MINOR=18
+MIN_GLIBC_MAJOR=2
+MIN_GLIBC_MINOR=17
+
+print_header() {
+    echo ""
+    echo -e "${BOLD}${BLUE}========================================${NC}"
+    echo -e "${BOLD}${BLUE}  $1${NC}"
+    echo -e "${BOLD}${BLUE}========================================${NC}"
+}
+
+print_section() {
+    echo ""
+    echo -e "${CYAN}--- $1 ---${NC}"
+}
+
+print_ok() {
+    echo -e "  ${GREEN}[OK]${NC} $1"
+}
+
+print_warn() {
+    echo -e "  ${YELLOW}[WARN]${NC} $1"
+}
+
+print_fail() {
+    echo -e "  ${RED}[FAIL]${NC} $1"
+}
+
+print_info() {
+    echo -e "  ${BLUE}[INFO]${NC} $1"
+}
+
+version_compare() {
+    # Returns 0 if version $1 is >= $3.$4 (the second argument is unused)
+    local v1_major v1_minor v2_major v2_minor
+    v1_major=$(echo "$1" | cut -d. -f1)
+    v1_minor=$(echo "$1" | cut -d. 
-f2) + v2_major=$3 + v2_minor=$4 + + if [ "$v1_major" -gt "$v2_major" ]; then + return 0 + elif [ "$v1_major" -eq "$v2_major" ] && [ "$v1_minor" -ge "$v2_minor" ]; then + return 0 + fi + return 1 +} + +print_header "Synapse System Diagnostic" +echo -e "Running diagnostics at: $(date)" +echo -e "Hostname: $(hostname)" + +# ============================================================================= +# LINUX DISTRIBUTION CHECK +# ============================================================================= +print_section "Linux Distribution" + +DISTRO_NAME="Unknown" +DISTRO_VERSION="Unknown" +DISTRO_ID="unknown" + +# Try /etc/os-release first (most modern distributions) +if [ -f /etc/os-release ]; then + . /etc/os-release + DISTRO_NAME="${NAME:-Unknown}" + DISTRO_VERSION="${VERSION:-${VERSION_ID:-Unknown}}" + DISTRO_ID="${ID:-unknown}" + DISTRO_ID_LIKE="${ID_LIKE:-}" + DISTRO_PRETTY="${PRETTY_NAME:-$DISTRO_NAME $DISTRO_VERSION}" + + print_info "Distribution: $DISTRO_PRETTY" + if [ -n "$DISTRO_ID_LIKE" ]; then + print_info "Based on: $DISTRO_ID_LIKE" + fi +# Try /etc/lsb-release (older Ubuntu/Debian) +elif [ -f /etc/lsb-release ]; then + . /etc/lsb-release + DISTRO_NAME="${DISTRIB_ID:-Unknown}" + DISTRO_VERSION="${DISTRIB_RELEASE:-Unknown}" + DISTRO_ID=$(echo "$DISTRO_NAME" | tr '[:upper:]' '[:lower:]') + + print_info "Distribution: $DISTRO_NAME $DISTRO_VERSION" + if [ -n "${DISTRIB_CODENAME:-}" ]; then + print_info "Codename: $DISTRIB_CODENAME" + fi +# Try specific release files +elif [ -f /etc/redhat-release ]; then + DISTRO_PRETTY=$(cat /etc/redhat-release) + DISTRO_ID="rhel" + print_info "Distribution: $DISTRO_PRETTY" +elif [ -f /etc/debian_version ]; then + DISTRO_VERSION=$(cat /etc/debian_version) + DISTRO_NAME="Debian" + DISTRO_ID="debian" + print_info "Distribution: Debian $DISTRO_VERSION" +elif [ -f /etc/alpine-release ]; then + DISTRO_VERSION=$(cat /etc/alpine-release) + DISTRO_NAME="Alpine Linux" + DISTRO_ID="alpine" + print_info "Distribution: Alpine Linux $DISTRO_VERSION" +elif [ -f /etc/arch-release ]; then + DISTRO_NAME="Arch Linux" + DISTRO_ID="arch" + print_info "Distribution: Arch Linux (rolling release)" +elif [ -f /etc/gentoo-release ]; then + DISTRO_PRETTY=$(cat /etc/gentoo-release) + DISTRO_ID="gentoo" + print_info "Distribution: $DISTRO_PRETTY" +elif [ -f /etc/SuSE-release ]; then + DISTRO_PRETTY=$(head -1 /etc/SuSE-release) + DISTRO_ID="suse" + print_info "Distribution: $DISTRO_PRETTY" +else + print_warn "Could not detect Linux distribution" +fi + +# Architecture +ARCH=$(uname -m) +print_info "Architecture: $ARCH" + +# Get kernel version early (needed for distribution-specific checks) +KERNEL_VERSION=$(uname -r) +KERNEL_MAJOR=$(echo "$KERNEL_VERSION" | cut -d. -f1) +KERNEL_MINOR=$(echo "$KERNEL_VERSION" | cut -d. 
-f2) + +# ============================================================================= +# HARDWARE / VIRTUALIZATION CHECK +# ============================================================================= +print_section "Hardware / Virtualization" + +IS_PHYSICAL=true +IS_CONTAINER=false +IS_VM=false +VIRT_TYPE="" +CONTAINER_TYPE="" +HYPERVISOR="" + +# Check if running in container first +if [ -f /.dockerenv ]; then + IS_CONTAINER=true + CONTAINER_TYPE="Docker" +elif [ -f /run/.containerenv ]; then + IS_CONTAINER=true + CONTAINER_TYPE="Podman" +elif grep -q "container=lxc" /proc/1/environ 2>/dev/null; then + IS_CONTAINER=true + CONTAINER_TYPE="LXC" +elif grep -q "/docker/" /proc/1/cgroup 2>/dev/null; then + IS_CONTAINER=true + CONTAINER_TYPE="Docker" +elif grep -q "/lxc/" /proc/1/cgroup 2>/dev/null; then + IS_CONTAINER=true + CONTAINER_TYPE="LXC" +elif grep -q "/kubepods/" /proc/1/cgroup 2>/dev/null; then + IS_CONTAINER=true + CONTAINER_TYPE="Kubernetes Pod" +elif [ -d /run/systemd/system ] && systemd-detect-virt --container &>/dev/null; then + DETECTED=$(systemd-detect-virt --container 2>/dev/null) + if [ "$DETECTED" != "none" ] && [ -n "$DETECTED" ]; then + IS_CONTAINER=true + CONTAINER_TYPE="$DETECTED" + fi +fi + +# Detect virtualization technology using multiple methods +detect_virtualization() { + # Method 1: systemd-detect-virt (most reliable on systemd systems) + if command -v systemd-detect-virt &> /dev/null; then + DETECTED=$(systemd-detect-virt --vm 2>/dev/null) || true + if [ -n "$DETECTED" ] && [ "$DETECTED" != "none" ]; then + IS_VM=true + IS_PHYSICAL=false + VIRT_TYPE="$DETECTED" + return 0 + fi + fi + + # Method 2: Check DMI/SMBIOS information (may not exist on ARM) + if [ -f /sys/class/dmi/id/product_name ]; then + PRODUCT_NAME=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true + case "$PRODUCT_NAME" in + *VirtualBox*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VirtualBox"; return 0 ;; + *VMware*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VMware"; return 0 ;; + *Virtual\ Machine*|*Hyper-V*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Hyper-V"; return 0 ;; + *KVM*|*QEMU*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="KVM/QEMU"; return 0 ;; + *Bochs*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Bochs"; return 0 ;; + *Parallels*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Parallels"; return 0 ;; + esac + fi + + # Method 3: Check sys_vendor + if [ -f /sys/class/dmi/id/sys_vendor ]; then + SYS_VENDOR=$(cat /sys/class/dmi/id/sys_vendor 2>/dev/null) || true + case "$SYS_VENDOR" in + *VMware*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VMware"; return 0 ;; + *innotek*|*Oracle*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VirtualBox"; return 0 ;; + *Xen*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Xen"; return 0 ;; + *Microsoft*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Hyper-V"; return 0 ;; + *QEMU*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="QEMU"; return 0 ;; + *Amazon\ EC2*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Amazon EC2 (Xen/Nitro)"; return 0 ;; + *Google*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Google Cloud"; return 0 ;; + *DigitalOcean*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="DigitalOcean"; return 0 ;; + esac + fi + + # Method 4: Check board_vendor + if [ -f /sys/class/dmi/id/board_vendor ]; then + BOARD_VENDOR=$(cat /sys/class/dmi/id/board_vendor 2>/dev/null) || true + case "$BOARD_VENDOR" in + *Amazon*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Amazon EC2"; return 0 ;; + esac + fi + + # Method 5: Check /proc/cpuinfo for hypervisor flag 
(x86/x86_64) + if grep -q "^flags.*hypervisor" /proc/cpuinfo 2>/dev/null; then + IS_VM=true + IS_PHYSICAL=false + # Try to identify which hypervisor + if grep -q "^flags.*vmx" /proc/cpuinfo 2>/dev/null; then + VIRT_TYPE="Unknown (nested virt with VMX)" + elif grep -q "^flags.*svm" /proc/cpuinfo 2>/dev/null; then + VIRT_TYPE="Unknown (nested virt with SVM)" + else + VIRT_TYPE="Unknown hypervisor" + fi + return 0 + fi + + # Method 6: Check for Xen via /sys/hypervisor + if [ -f /sys/hypervisor/type ]; then + HYPER_TYPE=$(cat /sys/hypervisor/type 2>/dev/null) || true + if [ -n "$HYPER_TYPE" ]; then + IS_VM=true + IS_PHYSICAL=false + VIRT_TYPE="$HYPER_TYPE" + return 0 + fi + fi + + # Method 7: Check dmesg for virtualization hints (requires root usually) + if command -v dmesg &> /dev/null; then + DMESG_OUTPUT=$(dmesg 2>/dev/null | grep -i "hypervisor detected" | head -1) || true + if [ -n "$DMESG_OUTPUT" ]; then + IS_VM=true + IS_PHYSICAL=false + VIRT_TYPE=$(echo "$DMESG_OUTPUT" | sed 's/.*: //') + return 0 + fi + fi + + # Method 8: Check for virt-what output if available + if command -v virt-what &> /dev/null; then + VIRT_WHAT=$(virt-what 2>/dev/null | head -1) || true + if [ -n "$VIRT_WHAT" ]; then + IS_VM=true + IS_PHYSICAL=false + VIRT_TYPE="$VIRT_WHAT" + return 0 + fi + fi + + # Method 9: ARM-specific checks + if [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "armv7l" ] || [ "$ARCH" = "armv8l" ]; then + # Check for device tree model (Raspberry Pi, etc.) + if [ -f /proc/device-tree/model ]; then + DT_MODEL=$(cat /proc/device-tree/model 2>/dev/null | tr -d '\0') || true + if [ -n "$DT_MODEL" ]; then + HW_MODEL="$DT_MODEL" + fi + fi + + # Check for ARM cloud VMs via device tree + if [ -f /proc/device-tree/hypervisor/compatible ]; then + HYPER_COMPAT=$(cat /proc/device-tree/hypervisor/compatible 2>/dev/null | tr -d '\0') || true + if [ -n "$HYPER_COMPAT" ]; then + IS_VM=true + IS_PHYSICAL=false + VIRT_TYPE="$HYPER_COMPAT" + return 0 + fi + fi + + # Check for common ARM cloud platforms + if [ -f /sys/firmware/devicetree/base/compatible ]; then + DT_COMPAT=$(cat /sys/firmware/devicetree/base/compatible 2>/dev/null | tr '\0' ' ') || true + case "$DT_COMPAT" in + *amazon,graviton*|*aws,nitro*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="AWS Graviton (Nitro)"; return 0 ;; + *google,*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Google Cloud ARM"; return 0 ;; + *azure,*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Azure ARM"; return 0 ;; + *oracle,*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Oracle Cloud ARM"; return 0 ;; + *ampere,*) + IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Ampere Cloud"; return 0 ;; + esac + fi + fi + + return 1 +} + +# Run virtualization detection (don't exit on failure) +detect_virtualization || true + +# Get hypervisor info if available +if [ -f /sys/hypervisor/uuid ]; then + HYPERVISOR_UUID=$(cat /sys/hypervisor/uuid 2>/dev/null) || true +fi + +# Display results +if [ "$IS_CONTAINER" = true ]; then + print_info "Environment: Container ($CONTAINER_TYPE)" + if [ "$IS_VM" = true ]; then + print_info "Host appears to be: Virtual Machine ($VIRT_TYPE)" + fi +elif [ "$IS_VM" = true ]; then + print_info "Environment: Virtual Machine" + print_ok "Virtualization: $VIRT_TYPE" + + # Additional VM details + if [ -f /sys/class/dmi/id/product_name ]; then + PRODUCT=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true + [ -n "$PRODUCT" ] && print_info "Product: $PRODUCT" + fi + if [ -f /sys/class/dmi/id/product_version ]; then + VERSION=$(cat /sys/class/dmi/id/product_version 2>/dev/null) || 
true + [ -n "$VERSION" ] && [ "$VERSION" != "None" ] && print_info "Version: $VERSION" + fi + if [ -n "$HYPERVISOR_UUID" ]; then + print_info "Hypervisor UUID: $HYPERVISOR_UUID" + fi + + # Check for cloud provider + CLOUD_PROVIDER="" + if [ -f /sys/class/dmi/id/chassis_asset_tag ]; then + CHASSIS_TAG=$(cat /sys/class/dmi/id/chassis_asset_tag 2>/dev/null) || true + case "$CHASSIS_TAG" in + *Amazon*) CLOUD_PROVIDER="AWS" ;; + *Google*) CLOUD_PROVIDER="Google Cloud" ;; + *Azure*) CLOUD_PROVIDER="Microsoft Azure" ;; + esac + fi + if [ -z "$CLOUD_PROVIDER" ] && [ -f /sys/class/dmi/id/board_asset_tag ]; then + BOARD_TAG=$(cat /sys/class/dmi/id/board_asset_tag 2>/dev/null) || true + case "$BOARD_TAG" in + i-*) CLOUD_PROVIDER="AWS (EC2 instance: $BOARD_TAG)" ;; + esac + fi + [ -n "$CLOUD_PROVIDER" ] && print_info "Cloud Provider: $CLOUD_PROVIDER" +else + # Physical machine or unknown + print_ok "Environment: Physical Machine" + + # Show hardware info + HW_INFO_FOUND=false + + # ARM: Check device tree model first + if [ -n "${HW_MODEL:-}" ]; then + print_info "Hardware: $HW_MODEL" + HW_INFO_FOUND=true + elif [ -f /proc/device-tree/model ]; then + DT_MODEL=$(cat /proc/device-tree/model 2>/dev/null | tr -d '\0') || true + if [ -n "$DT_MODEL" ]; then + print_info "Hardware: $DT_MODEL" + HW_INFO_FOUND=true + fi + fi + + # x86: Check DMI info + if [ -f /sys/class/dmi/id/product_name ]; then + PRODUCT=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true + if [ -n "$PRODUCT" ] && [ "$PRODUCT" != "None" ]; then + print_info "Product: $PRODUCT" + HW_INFO_FOUND=true + fi + fi + if [ -f /sys/class/dmi/id/sys_vendor ]; then + VENDOR=$(cat /sys/class/dmi/id/sys_vendor 2>/dev/null) || true + if [ -n "$VENDOR" ] && [ "$VENDOR" != "None" ]; then + print_info "Vendor: $VENDOR" + HW_INFO_FOUND=true + fi + fi + + # If no hardware info found, show CPU info + if [ "$HW_INFO_FOUND" = false ]; then + if [ -f /proc/cpuinfo ]; then + # ARM: look for Hardware or model name + CPU_MODEL=$(grep -m1 "^model name\|^Hardware\|^Model" /proc/cpuinfo 2>/dev/null | cut -d: -f2 | sed 's/^ //') || true + if [ -n "$CPU_MODEL" ]; then + print_info "CPU: $CPU_MODEL" + fi + fi + fi +fi + +# Check CPU virtualization support +if [ -f /proc/cpuinfo ]; then + # x86/x86_64: Check for VMX/SVM + if grep -q "^flags.*vmx" /proc/cpuinfo 2>/dev/null; then + print_info "CPU Virtualization: Intel VT-x supported" + elif grep -q "^flags.*svm" /proc/cpuinfo 2>/dev/null; then + print_info "CPU Virtualization: AMD-V supported" + # ARM: Check for virtualization extensions + elif [ "$ARCH" = "aarch64" ]; then + # Check CPU features for ARM virtualization + if grep -q "^Features.*:.*" /proc/cpuinfo 2>/dev/null; then + ARM_FEATURES=$(grep -m1 "^Features" /proc/cpuinfo | cut -d: -f2) || true + # ARM VHE (Virtualization Host Extensions) present in ARMv8.1+ + print_info "CPU: ARM64 (hardware virtualization capable)" + fi + fi +fi + +# Note about XDP in VMs +if [ "$IS_VM" = true ]; then + print_warn "Running in VM - XDP hardware offload not available" + print_info " XDP will use driver or SKB mode (still performant)" +fi + +# Distribution-specific notes for moat +case "$DISTRO_ID" in + alpine) + print_warn "Alpine Linux detected - ensure libc6-compat is installed for glibc binaries" + ;; + ubuntu|debian) + if [ "$KERNEL_MAJOR" -lt 5 ]; then + print_info "Consider upgrading to a newer kernel for better eBPF support" + fi + ;; + rhel|centos|rocky|almalinux) + print_info "RHEL-based distribution - ensure kernel-headers and bpftool packages are available" + ;; 
+esac + +# ============================================================================= +# KERNEL VERSION CHECK +# ============================================================================= +print_section "Kernel Version" + +KERNEL_VERSION=$(uname -r) +KERNEL_MAJOR=$(echo "$KERNEL_VERSION" | cut -d. -f1) +KERNEL_MINOR=$(echo "$KERNEL_VERSION" | cut -d. -f2) + +print_info "Kernel version: $KERNEL_VERSION" + +if version_compare "$KERNEL_MAJOR.$KERNEL_MINOR" "" $MIN_KERNEL_MAJOR $MIN_KERNEL_MINOR; then + print_ok "Kernel version >= ${MIN_KERNEL_MAJOR}.${MIN_KERNEL_MINOR} (required for XDP)" + FEATURES["xdp"]=1 +else + print_fail "Kernel version < ${MIN_KERNEL_MAJOR}.${MIN_KERNEL_MINOR} (XDP requires >= 4.18)" +fi + +# Check for specific kernel features via /boot/config if available +if [ -f "/boot/config-$KERNEL_VERSION" ]; then + print_info "Found kernel config at /boot/config-$KERNEL_VERSION" +elif [ -f "/proc/config.gz" ]; then + print_info "Found kernel config at /proc/config.gz" +fi + +# ============================================================================= +# GLIBC VERSION CHECK +# ============================================================================= +print_section "glibc Version" + +if command -v ldd &> /dev/null; then + GLIBC_VERSION=$(ldd --version 2>&1 | head -n1 | grep -oE '[0-9]+\.[0-9]+' | head -1) + if [ -n "$GLIBC_VERSION" ]; then + print_info "glibc version: $GLIBC_VERSION" + if version_compare "$GLIBC_VERSION" "" $MIN_GLIBC_MAJOR $MIN_GLIBC_MINOR; then + print_ok "glibc version >= ${MIN_GLIBC_MAJOR}.${MIN_GLIBC_MINOR}" + else + print_warn "glibc version < ${MIN_GLIBC_MAJOR}.${MIN_GLIBC_MINOR}" + fi + else + print_warn "Could not determine glibc version" + fi +else + print_warn "ldd not available, cannot check glibc version" +fi + +# Check musl as alternative +if command -v musl-ldd &> /dev/null || [ -f /lib/ld-musl-*.so.1 ]; then + MUSL_VERSION=$(ls /lib/ld-musl-*.so.1 2>/dev/null | head -1) + if [ -n "$MUSL_VERSION" ]; then + print_info "musl libc detected: $MUSL_VERSION" + fi +fi + +# ============================================================================= +# BPF/EBPF SUPPORT CHECK (without bpftool) +# ============================================================================= +print_section "eBPF Support" + +# Check if BPF syscall is available via /proc/kallsyms +if [ -f /proc/kallsyms ]; then + if grep -q "bpf_prog_load" /proc/kallsyms 2>/dev/null || grep -q " bpf_" /proc/kallsyms 2>/dev/null; then + print_ok "BPF syscall symbols found in kernel" + FEATURES["bpf_syscall"]=1 + else + # Try alternative check - /sys/fs/bpf + if [ -d /sys/fs/bpf ]; then + print_ok "BPF filesystem available at /sys/fs/bpf" + FEATURES["bpf_syscall"]=1 + else + print_fail "BPF syscall not available" + fi + fi +else + # Fallback: check /sys/fs/bpf + if [ -d /sys/fs/bpf ]; then + print_ok "BPF filesystem available at /sys/fs/bpf" + FEATURES["bpf_syscall"]=1 + else + print_warn "Cannot verify BPF support (/proc/kallsyms not readable)" + fi +fi + +# Check for BPF JIT +if [ -f /proc/sys/net/core/bpf_jit_enable ]; then + BPF_JIT=$(cat /proc/sys/net/core/bpf_jit_enable 2>/dev/null) + if [ "$BPF_JIT" = "1" ] || [ "$BPF_JIT" = "2" ]; then + print_ok "BPF JIT is enabled (value: $BPF_JIT)" + FEATURES["bpf_jit"]=1 + else + print_warn "BPF JIT is disabled (value: $BPF_JIT)" + print_info " Enable with: echo 1 > /proc/sys/net/core/bpf_jit_enable" + fi +else + print_warn "BPF JIT sysctl not found" +fi + +# Check for BTF (BPF Type Format) support +if [ -f /sys/kernel/btf/vmlinux ]; then 
+ print_ok "BTF (BPF Type Format) is available" + FEATURES["btf"]=1 +else + print_warn "BTF not available (CO-RE features may be limited)" + print_info " BTF requires CONFIG_DEBUG_INFO_BTF=y in kernel config" +fi + +# Check BPF hardening settings +if [ -f /proc/sys/kernel/unprivileged_bpf_disabled ]; then + UNPRIV_BPF=$(cat /proc/sys/kernel/unprivileged_bpf_disabled 2>/dev/null) + case "$UNPRIV_BPF" in + 0) + print_info "Unprivileged BPF: allowed" + ;; + 1) + print_info "Unprivileged BPF: disabled (root required)" + ;; + 2) + print_info "Unprivileged BPF: permanently disabled" + ;; + *) + print_info "Unprivileged BPF setting: $UNPRIV_BPF" + ;; + esac +fi + +# Check BPF JIT hardening +if [ -f /proc/sys/net/core/bpf_jit_harden ]; then + BPF_HARDEN=$(cat /proc/sys/net/core/bpf_jit_harden 2>/dev/null) + case "$BPF_HARDEN" in + 0) + print_info "BPF JIT hardening: disabled" + ;; + 1) + print_info "BPF JIT hardening: enabled for unprivileged" + ;; + 2) + print_info "BPF JIT hardening: enabled for all" + ;; + esac +fi + +# Check memlock limit (critical for BPF maps) +print_section "Memory Limits" + +MEMLOCK_SUFFICIENT=false +MIN_MEMLOCK_KB=65536 # 64MB minimum recommended for BPF + +# Method 1: Check ulimit +if command -v ulimit &> /dev/null; then + MEMLOCK_SOFT=$(ulimit -l 2>/dev/null) || true + MEMLOCK_HARD=$(ulimit -Hl 2>/dev/null) || true + + if [ "$MEMLOCK_SOFT" = "unlimited" ]; then + print_ok "Memlock (soft limit): unlimited" + MEMLOCK_SUFFICIENT=true + elif [ -n "$MEMLOCK_SOFT" ] && [ "$MEMLOCK_SOFT" -ge "$MIN_MEMLOCK_KB" ] 2>/dev/null; then + print_ok "Memlock (soft limit): ${MEMLOCK_SOFT} KB (sufficient)" + MEMLOCK_SUFFICIENT=true + elif [ -n "$MEMLOCK_SOFT" ]; then + print_warn "Memlock (soft limit): ${MEMLOCK_SOFT} KB (may be too low for BPF)" + print_info " Recommended: >= ${MIN_MEMLOCK_KB} KB or unlimited" + fi + + if [ "$MEMLOCK_HARD" = "unlimited" ]; then + print_info "Memlock (hard limit): unlimited" + elif [ -n "$MEMLOCK_HARD" ]; then + print_info "Memlock (hard limit): ${MEMLOCK_HARD} KB" + fi +fi + +# Method 2: Check /proc/self/limits for more detail +if [ -f /proc/self/limits ]; then + LIMITS_LINE=$(grep "Max locked memory" /proc/self/limits 2>/dev/null) || true + if [ -n "$LIMITS_LINE" ]; then + SOFT_BYTES=$(echo "$LIMITS_LINE" | awk '{print $4}') + HARD_BYTES=$(echo "$LIMITS_LINE" | awk '{print $5}') + + if [ "$SOFT_BYTES" = "unlimited" ]; then + print_info "Process memlock: unlimited" + MEMLOCK_SUFFICIENT=true + elif [ -n "$SOFT_BYTES" ] && [ "$SOFT_BYTES" != "unlimited" ]; then + SOFT_MB=$((SOFT_BYTES / 1024 / 1024)) + print_info "Process memlock: ${SOFT_MB} MB (${SOFT_BYTES} bytes)" + fi + fi +fi + +# Method 3: Check systemd default if applicable +if [ -f /etc/systemd/system.conf ]; then + SYSTEMD_MEMLOCK=$(grep -E "^DefaultLimitMEMLOCK=" /etc/systemd/system.conf 2>/dev/null | cut -d= -f2) || true + if [ -n "$SYSTEMD_MEMLOCK" ]; then + print_info "Systemd DefaultLimitMEMLOCK: $SYSTEMD_MEMLOCK" + fi +fi + +# Provide fix instructions if memlock is insufficient +if [ "$MEMLOCK_SUFFICIENT" = false ]; then + print_warn "Memlock limit may be insufficient for BPF operations" + print_info " To fix temporarily: ulimit -l unlimited" + print_info " To fix permanently, add to /etc/security/limits.conf:" + print_info " * soft memlock unlimited" + print_info " * hard memlock unlimited" + print_info " Or for systemd services, add to unit file:" + print_info " LimitMEMLOCK=infinity" +fi + +# ============================================================================= +# KERNEL BPF 
CONFIG PARAMETERS +# ============================================================================= +print_section "Kernel BPF Configuration" + +check_kernel_config() { + local config_name="$1" + local description="$2" + local config_file="" + + # Find kernel config + if [ -f "/boot/config-$KERNEL_VERSION" ]; then + config_file="/boot/config-$KERNEL_VERSION" + elif [ -f "/proc/config.gz" ]; then + if command -v zcat &> /dev/null; then + if zcat /proc/config.gz 2>/dev/null | grep -q "^${config_name}="; then + print_ok "$config_name ($description)" + return 0 + elif zcat /proc/config.gz 2>/dev/null | grep -q "^# ${config_name} is not set"; then + print_fail "$config_name ($description) - not enabled" + return 1 + fi + fi + return 2 + fi + + if [ -n "$config_file" ]; then + if grep -q "^${config_name}=y" "$config_file" 2>/dev/null; then + print_ok "$config_name ($description)" + return 0 + elif grep -q "^${config_name}=m" "$config_file" 2>/dev/null; then + print_ok "$config_name ($description) [module]" + return 0 + elif grep -q "^# ${config_name} is not set" "$config_file" 2>/dev/null; then + print_fail "$config_name ($description) - not enabled" + return 1 + fi + fi + + print_info "$config_name ($description) - cannot verify" + return 2 +} + +# Required BPF configs +check_kernel_config "CONFIG_BPF" "Basic BPF support" +check_kernel_config "CONFIG_BPF_SYSCALL" "BPF system calls" +check_kernel_config "CONFIG_BPF_JIT" "BPF JIT compiler" +check_kernel_config "CONFIG_HAVE_EBPF_JIT" "eBPF JIT support" + +# XDP specific configs +check_kernel_config "CONFIG_XDP_SOCKETS" "XDP socket support" +check_kernel_config "CONFIG_NET_CLS_BPF" "BPF classifier" +check_kernel_config "CONFIG_NET_ACT_BPF" "BPF action module" + +# BTF and debugging +check_kernel_config "CONFIG_DEBUG_INFO_BTF" "BTF debug info" +check_kernel_config "CONFIG_BPF_LSM" "BPF LSM support" + +# Additional useful configs +check_kernel_config "CONFIG_CGROUP_BPF" "Cgroup BPF support" +check_kernel_config "CONFIG_BPF_STREAM_PARSER" "BPF stream parser" + +# ============================================================================= +# XDP SUPPORT CHECK +# ============================================================================= +print_section "XDP Support" + +# Check for XDP support in network drivers +if [ -d /sys/class/net ]; then + print_info "Available network interfaces:" + for iface in /sys/class/net/*; do + iface_name=$(basename "$iface") + if [ "$iface_name" != "lo" ]; then + driver="" + if [ -L "$iface/device/driver" ]; then + driver=$(basename "$(readlink "$iface/device/driver")") + fi + + # Check for XDP support indicators + xdp_mode="unknown" + if [ -f "$iface/xdp" ] || [ -d "$iface/xdp" ]; then + xdp_mode="supported" + fi + + if [ -n "$driver" ]; then + print_info " $iface_name (driver: $driver, xdp: $xdp_mode)" + else + print_info " $iface_name (xdp: $xdp_mode)" + fi + fi + done +fi + +# Check XDP modes availability (based on kernel version) +if [ "$KERNEL_MAJOR" -ge 5 ] || ([ "$KERNEL_MAJOR" -eq 4 ] && [ "$KERNEL_MINOR" -ge 18 ]); then + print_ok "XDP SKB (generic) mode supported" + FEATURES["xdp_skb"]=1 +fi + +if [ "$KERNEL_MAJOR" -ge 5 ] || ([ "$KERNEL_MAJOR" -eq 4 ] && [ "$KERNEL_MINOR" -ge 18 ]); then + print_ok "XDP driver mode potentially supported (driver dependent)" + FEATURES["xdp_driver"]=1 +fi + +if [ "$KERNEL_MAJOR" -ge 5 ]; then + print_ok "XDP hardware offload potentially supported (NIC dependent)" + FEATURES["xdp_hardware"]=1 +fi + +# ============================================================================= 
+# NFTABLES CHECK +# ============================================================================= +print_section "nftables Support" + +if command -v nft &> /dev/null; then + NFT_VERSION=$(nft --version 2>&1 | head -1) + print_ok "nft command available: $NFT_VERSION" + FEATURES["nftables"]=1 + + # Check if we can list tables (may need root) + if nft list tables &> /dev/null; then + print_ok "nft can list tables (have permissions)" + + # Check for existing moat/synapse tables + if nft list tables 2>/dev/null | grep -q "synapse"; then + print_info "Found existing 'synapse' table" + fi + else + print_warn "nft cannot list tables (may need root privileges)" + fi +else + print_fail "nft command not found" + print_info " Install with: apt install nftables (Debian/Ubuntu)" + print_info " Or: yum install nftables (RHEL/CentOS)" +fi + +# Check nftables kernel module +if [ -f /proc/modules ]; then + if grep -q "^nf_tables" /proc/modules 2>/dev/null; then + print_ok "nf_tables kernel module loaded" + else + # Module might be built-in + if [ -d /sys/module/nf_tables ]; then + print_ok "nf_tables module available" + else + print_warn "nf_tables module not loaded" + print_info " Load with: modprobe nf_tables" + fi + fi +fi + +# ============================================================================= +# IPTABLES CHECK +# ============================================================================= +print_section "iptables Support" + +# IPv4 iptables +if command -v iptables &> /dev/null; then + IPT_VERSION=$(iptables --version 2>&1) + print_ok "iptables available: $IPT_VERSION" + FEATURES["iptables"]=1 + + # Check if it's nft backend or legacy + if echo "$IPT_VERSION" | grep -q "nf_tables"; then + print_info " Using nftables backend (iptables-nft)" + elif echo "$IPT_VERSION" | grep -q "legacy"; then + print_info " Using legacy backend (iptables-legacy)" + fi + + # Check permissions + if iptables -L -n &> /dev/null; then + print_ok "iptables can list rules (have permissions)" + + # Check for existing SYNAPSE_BLOCK chain + if iptables -L SYNAPSE_BLOCK &> /dev/null; then + print_info "Found existing 'SYNAPSE_BLOCK' chain" + fi + else + print_warn "iptables cannot list rules (may need root privileges)" + fi +else + print_fail "iptables command not found" +fi + +# IPv6 ip6tables +if command -v ip6tables &> /dev/null; then + IP6T_VERSION=$(ip6tables --version 2>&1) + print_ok "ip6tables available: $IP6T_VERSION" + FEATURES["ip6tables"]=1 + + if ip6tables -L -n &> /dev/null; then + print_ok "ip6tables can list rules (have permissions)" + else + print_warn "ip6tables cannot list rules (may need root privileges)" + fi +else + print_warn "ip6tables command not found (IPv6 filtering via iptables unavailable)" +fi + +# ============================================================================= +# IPV6 SUPPORT +# ============================================================================= +print_section "IPv6 Support" + +if [ -f /proc/sys/net/ipv6/conf/all/disable_ipv6 ]; then + IPV6_DISABLED=$(cat /proc/sys/net/ipv6/conf/all/disable_ipv6 2>/dev/null) + if [ "$IPV6_DISABLED" = "0" ]; then + print_ok "IPv6 is enabled system-wide" + FEATURES["ipv6"]=1 + else + print_warn "IPv6 is disabled system-wide" + print_info " Note: XDP may require IPv6 enabled per-interface" + print_info " Synapse can enable IPv6 per-interface automatically" + fi +else + print_warn "Cannot determine IPv6 status" +fi + +# Check per-interface IPv6 +if [ -d /sys/class/net ]; then + ipv6_interfaces="" + for iface in /sys/class/net/*; do + 
iface_name=$(basename "$iface")
+        if [ "$iface_name" != "lo" ]; then
+            if [ -f "/proc/sys/net/ipv6/conf/$iface_name/disable_ipv6" ]; then
+                ipv6_val=$(cat "/proc/sys/net/ipv6/conf/$iface_name/disable_ipv6" 2>/dev/null)
+                if [ "$ipv6_val" = "0" ]; then
+                    ipv6_interfaces="$ipv6_interfaces $iface_name(enabled)"
+                else
+                    ipv6_interfaces="$ipv6_interfaces $iface_name(disabled)"
+                fi
+            fi
+        fi
+    done
+    if [ -n "$ipv6_interfaces" ]; then
+        print_info "Per-interface IPv6:$ipv6_interfaces"
+    fi
+fi
+
+# =============================================================================
+# CAPABILITIES CHECK
+# =============================================================================
+print_section "Required Capabilities"
+
+print_info "Synapse requires the following capabilities:"
+print_info "  CAP_SYS_ADMIN    - Load/manage BPF programs"
+print_info "  CAP_NET_ADMIN    - Network administration"
+print_info "  CAP_BPF          - BPF operations (kernel >= 5.8)"
+print_info "  CAP_PERFMON      - Performance monitoring"
+print_info "  CAP_SYS_RESOURCE - Unlimited locked memory"
+
+# Check if running as root
+if [ "$(id -u)" = "0" ]; then
+    print_ok "Running as root (all capabilities available)"
+else
+    print_warn "Not running as root"
+
+    # Check current capabilities if capsh is available
+    if command -v capsh &> /dev/null; then
+        print_info "Current capabilities:"
+        capsh --print 2>/dev/null | grep "Current" | head -1 || true
+    fi
+fi
+
+# Check locked memory limit
+if command -v ulimit &> /dev/null; then
+    MEMLOCK=$(ulimit -l 2>/dev/null)
+    if [ "$MEMLOCK" = "unlimited" ]; then
+        print_ok "Locked memory limit: unlimited"
+    else
+        print_warn "Locked memory limit: ${MEMLOCK}KB"
+        print_info "  Consider: ulimit -l unlimited"
+    fi
+fi
+
+# =============================================================================
+# SUMMARY
+# =============================================================================
+print_header "Feature Summary"
+
+echo ""
+echo -e "${BOLD}Core Features:${NC}"
+echo -e "  XDP Support:          $([ ${FEATURES["xdp"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+echo -e "    - Hardware Offload: $([ ${FEATURES["xdp_hardware"]} -eq 1 ] && echo -e "${GREEN}Possible${NC}" || echo -e "${YELLOW}Unlikely${NC}")"
+echo -e "    - Driver Mode:      $([ ${FEATURES["xdp_driver"]} -eq 1 ] && echo -e "${GREEN}Possible${NC}" || echo -e "${YELLOW}Unlikely${NC}")"
+echo -e "    - SKB/Generic Mode: $([ ${FEATURES["xdp_skb"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+echo -e "  BPF Syscall:          $([ ${FEATURES["bpf_syscall"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+echo -e "  BPF JIT:              $([ ${FEATURES["bpf_jit"]} -eq 1 ] && echo -e "${GREEN}Enabled${NC}" || echo -e "${YELLOW}Disabled${NC}")"
+echo -e "  BTF Support:          $([ ${FEATURES["btf"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${YELLOW}Not Available${NC}")"
+
+echo ""
+echo -e "${BOLD}Fallback Firewalls:${NC}"
+echo -e "  nftables:         $([ ${FEATURES["nftables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+echo -e "  iptables (IPv4):  $([ ${FEATURES["iptables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+echo -e "  ip6tables (IPv6): $([ ${FEATURES["ip6tables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")"
+
+echo ""
+echo -e "${BOLD}Network:${NC}"
+echo -e "  IPv6: $([ ${FEATURES["ipv6"]} -eq 1 ] && echo -e "${GREEN}Enabled${NC}" || echo -e "${YELLOW}Disabled${NC}")"
+
+echo ""
+echo -e "${BOLD}Synapse Operational Mode:${NC}"
+if [ ${FEATURES["xdp"]} -eq 1 ] && [ ${FEATURES["bpf_syscall"]} -eq 1 ]; then
+    echo -e "  ${GREEN}XDP mode available${NC} - Best performance"
+    echo -e "  Synapse will use XDP for packet filtering with fallback chain:"
+    echo -e "    1. Hardware offload (if NIC supports)"
+    echo -e "    2. Driver mode (if driver supports)"
+    echo -e "    3. SKB/Generic mode (guaranteed fallback)"
+elif [ ${FEATURES["nftables"]} -eq 1 ]; then
+    echo -e "  ${YELLOW}nftables fallback mode${NC}"
+    echo -e "  XDP not available, Synapse will use nftables for filtering"
+elif [ ${FEATURES["iptables"]} -eq 1 ]; then
+    echo -e "  ${YELLOW}iptables fallback mode${NC}"
+    echo -e "  XDP and nftables not available, Synapse will use iptables"
+else
+    echo -e "  ${RED}No packet filtering available!${NC}"
+    echo -e "  Synapse cannot operate without XDP, nftables, or iptables"
+fi
+
+echo ""
+echo -e "${BOLD}Recommendations:${NC}"
+
+if [ ${FEATURES["bpf_jit"]} -eq 0 ]; then
+    echo -e "  ${YELLOW}*${NC} Enable BPF JIT for better performance:"
+    echo -e "      echo 1 > /proc/sys/net/core/bpf_jit_enable"
+fi
+
+if [ ${FEATURES["btf"]} -eq 0 ]; then
+    echo -e "  ${YELLOW}*${NC} Rebuild kernel with CONFIG_DEBUG_INFO_BTF=y for CO-RE support"
+fi
+
+if [ ${FEATURES["ipv6"]} -eq 0 ]; then
+    echo -e "  ${YELLOW}*${NC} Consider enabling IPv6 (XDP may require it per-interface)"
+fi
+
+if [ "$(id -u)" != "0" ]; then
+    echo -e "  ${YELLOW}*${NC} Run Synapse as root or with required capabilities"
+fi
+
+# No recommendations needed
+if [ ${FEATURES["xdp"]} -eq 1 ] && [ ${FEATURES["bpf_syscall"]} -eq 1 ] && \
+   [ ${FEATURES["bpf_jit"]} -eq 1 ] && [ ${FEATURES["btf"]} -eq 1 ]; then
+    echo -e "  ${GREEN}System is optimally configured for Synapse!${NC}"
+fi
+
+echo ""
+print_header "Diagnostic Complete"
diff --git a/pkg/debug/local_build.sh b/pkg/debug/local_build.sh
new file mode 100644
index 0000000..5eb720f
--- /dev/null
+++ b/pkg/debug/local_build.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -euo pipefail
+
+PLATFORM=${1:-"linux/amd64"}
+LEGACY=${2:-"false"}
+BUILD_FLAGS=${3:-""}
+
+docker rm -f temp-container 2>/dev/null || true
+
+if [ "$LEGACY" = "true" ]; then
+    DOCKERFILE="pkg/debug/build-legacy.Dockerfile"
+else
+    DOCKERFILE="pkg/debug/build.Dockerfile"
+fi
+
+BUILD_OPTS=()
+[ -n "$PLATFORM" ] && BUILD_OPTS+=(--platform "$PLATFORM") || true
+[ -n "$BUILD_FLAGS" ] && BUILD_OPTS+=(--build-arg "BUILD_FLAGS=$BUILD_FLAGS") || true
+docker buildx build "${BUILD_OPTS[@]}" --load -f "$DOCKERFILE" -t synapse-builder .
+docker create --name temp-container synapse-builder
+docker cp temp-container:/output/synapse ./synapse
diff --git a/src/agent_status.rs b/src/agent_status.rs
new file mode 100644
index 0000000..68fb1f3
--- /dev/null
+++ b/src/agent_status.rs
@@ -0,0 +1,78 @@
+use std::collections::HashMap;
+
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AgentStatusEvent {
+    pub event_type: String,
+    pub schema_version: String,
+    pub timestamp: DateTime<Utc>,
+    pub agent_id: String,
+    pub agent_name: String,
+    pub hostname: String,
+    pub version: String,
+    pub mode: String,
+    pub status: String,
+    pub pid: u32,
+    pub started_at: DateTime<Utc>,
+    pub last_seen: DateTime<Utc>,
+    pub uptime_secs: u64,
+    pub tags: Vec<String>,
+    pub capabilities: Vec<String>,
+    pub interfaces: Vec<String>,
+    pub ip_addresses: Vec<String>,
+    pub metadata: HashMap<String, String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct AgentStatusIdentity {
+    pub agent_id: String,
+    pub agent_name: String,
+    pub hostname: String,
+    pub version: String,
+    pub mode: String,
+    pub tags: Vec<String>,
+    pub capabilities: Vec<String>,
+    pub interfaces: Vec<String>,
+    pub ip_addresses: Vec<String>,
+    pub metadata: HashMap<String, String>,
+    pub started_at: DateTime<Utc>,
+}
+
+impl AgentStatusIdentity {
+    pub fn to_event(&self, status: &str) -> AgentStatusEvent {
+        let now = Utc::now();
+        let uptime_secs = now
+            .signed_duration_since(self.started_at)
+            .num_seconds()
+            .max(0) as u64;
+
+        AgentStatusEvent {
+            event_type: "agent_status".to_string(),
+            schema_version: "1.0.0".to_string(),
+            timestamp: now,
+            agent_id: self.agent_id.clone(),
+            agent_name: self.agent_name.clone(),
+            hostname: self.hostname.clone(),
+            version: self.version.clone(),
+            mode: self.mode.clone(),
+            status: status.to_string(),
+            pid: std::process::id(),
+            started_at: self.started_at,
+            last_seen: now,
+            uptime_secs,
+            tags: self.tags.clone(),
+            capabilities: self.capabilities.clone(),
+            interfaces: self.interfaces.clone(),
+            ip_addresses: self.ip_addresses.clone(),
+            metadata: self.metadata.clone(),
+        }
+    }
+}
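+
+// Usage sketch (illustrative, not part of the worker shown in this patch): a
+// heartbeat tick would turn the identity into an event and serialize it, e.g.
+//     let event = identity.to_event("online");
+//     let json = serde_json::to_string(&event)?; // AgentStatusEvent derives Serialize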
+
diff --git a/src/bpf_stats_noop.rs b/src/bpf_stats_noop.rs
new file mode 100644
index 0000000..db34e09
--- /dev/null
+++ b/src/bpf_stats_noop.rs
@@ -0,0 +1,236 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+
+use crate::bpf::FilterSkel;
+
+/// BPF statistics collected from kernel-level access rule enforcement
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct BpfAccessStats {
+    pub timestamp: DateTime<Utc>,
+    pub total_packets_processed: u64,
+    pub total_packets_dropped: u64,
+    pub ipv4_banned_hits: u64,
+    pub ipv4_recently_banned_hits: u64,
+    pub ipv6_banned_hits: u64,
+    pub ipv6_recently_banned_hits: u64,
+    pub tcp_fingerprint_blocks_ipv4: u64,
+    pub tcp_fingerprint_blocks_ipv6: u64,
+    pub drop_rate_percentage: f64,
+    pub dropped_ip_addresses: DroppedIpAddresses,
+}
+
+/// Statistics about dropped IP addresses
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DroppedIpAddresses {
+    pub ipv4_addresses: HashMap<String, u64>,
+    pub ipv6_addresses: HashMap<String, u64>,
+    pub total_unique_dropped_ips: u64,
+}
+
+/// Individual event for a dropped IP address
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DroppedIpEvent {
+    pub event_type: String,
+    pub timestamp: DateTime<Utc>,
+    pub ip_address: String,
+    pub ip_version: IpVersion,
+    pub drop_count: u64,
+    pub drop_reason: DropReason,
+}
+
+/// IP version enumeration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum IpVersion {
+    IPv4,
+    IPv6,
+}
+
+/// Reason for dropping packets
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum DropReason {
+    AccessRules,
+    RecentlyBannedUdp,
+    RecentlyBannedIcmp,
+    RecentlyBannedTcpFinRst,
+}
+
+/// Collection of dropped IP events
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DroppedIpEvents {
+    pub timestamp: DateTime<Utc>,
+    pub events: Vec<DroppedIpEvent>,
+    pub total_events: u64,
+    pub unique_ips: u64,
+}
+
+impl BpfAccessStats {
+    pub fn empty() -> Self {
+        Self {
+            timestamp: Utc::now(),
+            total_packets_processed: 0,
+            total_packets_dropped: 0,
+            ipv4_banned_hits: 0,
+            ipv4_recently_banned_hits: 0,
+            ipv6_banned_hits: 0,
+            ipv6_recently_banned_hits: 0,
+            tcp_fingerprint_blocks_ipv4: 0,
+            tcp_fingerprint_blocks_ipv6: 0,
+            drop_rate_percentage: 0.0,
+            dropped_ip_addresses: DroppedIpAddresses {
+                ipv4_addresses: HashMap::new(),
+                ipv6_addresses: HashMap::new(),
+                total_unique_dropped_ips: 0,
+            },
+        }
+    }
+
+    /// Create a summary string for logging
+    pub fn summary(&self) -> String {
+        format!(
+            "BPF Stats: {} packets processed, {} dropped ({:.2}%), {} unique IPs dropped",
+            self.total_packets_processed,
+            self.total_packets_dropped,
+            self.drop_rate_percentage,
+            self.dropped_ip_addresses.total_unique_dropped_ips
+        )
+    }
+}
+
+impl DroppedIpEvent {
+    pub fn new(
+        ip_address: String,
+        ip_version: IpVersion,
+        drop_count: u64,
+        drop_reason: DropReason,
+    ) -> Self {
+        Self {
+            event_type: "dropped_ips".to_string(),
+            timestamp: Utc::now(),
+            ip_address,
+            ip_version,
+            drop_count,
+            drop_reason,
+        }
+    }
+
+    pub fn to_json(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+
+    pub fn summary(&self) -> String {
+        format!(
+            "IP Drop Event: {} {:?} dropped {} times (reason: {:?})",
+            self.ip_address,
+            self.ip_version,
+            self.drop_count,
+            self.drop_reason
+        )
+    }
+}
+
+impl DroppedIpEvents {
+    pub fn new() -> Self {
+        Self {
+            timestamp: Utc::now(),
+            events: Vec::new(),
+            total_events: 0,
+            unique_ips: 0,
+        }
+    }
+
+    pub fn add_event(&mut self, event: DroppedIpEvent) {
+        self.events.push(event);
+        self.total_events += 1;
+    }
+
+    pub fn to_json(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+
+    pub fn summary(&self) -> String {
+        format!(
+            "Dropped IP Events: {} events from {} unique IPs",
+            self.total_events,
+            self.unique_ips
+        )
+    }
+
+    pub fn get_top_dropped_ips(&self, limit: usize) -> Vec<DroppedIpEvent> {
+        let mut events = self.events.clone();
+        events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count));
+        events.into_iter().take(limit).collect()
+    }
+}
+
+/// Statistics collector for BPF access rules
+#[derive(Clone)]
+pub struct BpfStatsCollector {
+    _skels: Vec<Arc<FilterSkel<'static>>>,
+    enabled: bool,
+}
+
+impl BpfStatsCollector {
+    pub fn new(skels: Vec<Arc<FilterSkel<'static>>>, enabled: bool) -> Self {
+        Self { _skels: skels, enabled }
+    }
+
+    pub fn set_enabled(&mut self, enabled: bool) {
+        self.enabled = enabled;
+    }
+
+    pub fn is_enabled(&self) -> bool {
+        self.enabled
+    }
+
+    pub fn collect_stats(&self) -> Result<Vec<BpfAccessStats>, Box<dyn std::error::Error>> {
+        Ok(Vec::new())
+    }
+
+    pub fn collect_aggregated_stats(&self) -> Result<BpfAccessStats, Box<dyn std::error::Error>> {
+        Ok(BpfAccessStats::empty())
+    }
+
+    pub fn log_stats(&self) -> Result<(), Box<dyn std::error::Error>> {
+        Ok(())
+    }
+
+    pub fn collect_dropped_ip_events(&self) -> Result<DroppedIpEvents, Box<dyn std::error::Error>> {
+        Ok(DroppedIpEvents::new())
+    }
+
+    pub fn log_dropped_ip_events(&self) -> Result<(), Box<dyn std::error::Error>> {
+        Ok(())
+    }
+
+    pub fn reset_dropped_ip_counters(&self) -> Result<(), Box<dyn std::error::Error>> {
+        Ok(())
+    }
+}
+
+/// Configuration for BPF statistics collection
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct BpfStatsConfig {
+    pub enabled: bool,
+    pub log_interval_secs: u64,
+}
+
+impl Default for BpfStatsConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            log_interval_secs: 60,
+        }
+    }
+}
+
+impl BpfStatsConfig {
+    pub fn new(enabled: bool, log_interval_secs: u64) -> Self {
+        Self {
+            enabled,
+            log_interval_secs,
+        }
+    }
+}
diff --git a/src/bpf_stub.rs b/src/bpf_stub.rs
new file mode 100644
index 0000000..6957e7d
--- /dev/null
+++ b/src/bpf_stub.rs
@@ -0,0 +1,29 @@
+use std::marker::PhantomData;
+
+#[derive(Clone, Debug)]
+pub struct FilterSkel<'a> {
+    _marker: PhantomData<&'a ()>,
+}
+
+impl<'a> FilterSkel<'a> {
+    pub fn new() -> Self {
+        Self { _marker: PhantomData }
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct FilterSkelBuilder;
+
+impl FilterSkelBuilder {
+    pub fn open(&self) -> Result<FilterSkelOpen, Box<dyn std::error::Error>> {
+        Err("BPF support disabled at build time".into())
+    }
+}
+
+pub struct FilterSkelOpen;
+
+impl FilterSkelOpen {
+    pub fn load(self) -> Result<FilterSkel<'static>, Box<dyn std::error::Error>> {
+        Err("BPF support disabled at build time".into())
+    }
+}
diff --git a/src/firewall_noop.rs b/src/firewall_noop.rs
new file mode 100644
index 0000000..981cae6
--- /dev/null
+++ b/src/firewall_noop.rs
@@ -0,0 +1,290 @@
+use std::error::Error;
+use std::marker::PhantomData;
+use std::net::{Ipv4Addr, Ipv6Addr};
+
+use serde::{Deserialize, Serialize};
+
+/// Enum to represent the active firewall backend
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum FirewallBackend {
+    Xdp,
+    Nftables,
+    Iptables,
+    None,
+}
+
+impl std::fmt::Display for FirewallBackend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FirewallBackend::Xdp => write!(f, "XDP/BPF"),
+            FirewallBackend::Nftables => write!(f, "nftables"),
+            FirewallBackend::Iptables => write!(f, "iptables"),
+            FirewallBackend::None => write!(f, "none (userland)"),
+        }
+    }
+}
+
+/// Configuration option for forcing a specific firewall backend
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase")]
+pub enum FirewallMode {
+    #[default]
+    Auto,
+    Xdp,
+    Nftables,
+    Iptables,
+    None,
+}
+
+impl std::fmt::Display for FirewallMode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FirewallMode::Auto => write!(f, "auto"),
+            FirewallMode::Xdp => write!(f, "xdp"),
+            FirewallMode::Nftables => write!(f, "nftables"),
+            FirewallMode::Iptables => write!(f, "iptables"),
+            FirewallMode::None => write!(f, "none"),
+        }
+    }
+}
+
+pub trait Firewall {
+    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>>;
+
+    // IPv6 methods
+    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>>;
+
+    // TCP fingerprint blocking methods
+    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
+    fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
+}
+
+pub struct 
SYNAPSEFirewall<'a> { + _skel: &'a crate::bpf::FilterSkel<'a>, +} + +impl<'a> SYNAPSEFirewall<'a> { + pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self { + Self { _skel: skel } + } +} + +impl<'a> Firewall for SYNAPSEFirewall<'a> { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} + +pub struct NftablesFirewall { + _marker: PhantomData<()>, +} + +impl NftablesFirewall { + pub fn new() -> Result> { + Ok(Self { _marker: PhantomData }) + } + + pub fn is_available() -> bool { + false + } + + pub fn cleanup(&self) -> Result<(), Box> { + Ok(()) + } +} + +impl Firewall for NftablesFirewall { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} + +pub struct IptablesFirewall { + _marker: PhantomData<()>, +} + +impl IptablesFirewall { + pub fn new() -> Result> { + Ok(Self { _marker: PhantomData }) + } + + pub fn is_available() -> bool { + false + } + + pub fn cleanup(&self) -> Result<(), Box> { + Ok(()) + } +} + +impl Firewall for IptablesFirewall { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> 
Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} diff --git a/src/main.rs b/src/main.rs index b4efd98..f9dc58b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::mem::MaybeUninit; use std::sync::Arc; use std::str::FromStr; @@ -8,12 +9,14 @@ use clap::Parser; use daemonize::Daemonize; use libbpf_rs::skel::{OpenSkel, SkelBuilder}; use nix::net::if_::if_nametoindex; +use chrono::Utc; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; pub mod access_log; pub mod access_rules; +pub mod agent_status; pub mod app_state; pub mod captcha_server; pub mod cli; @@ -52,6 +55,8 @@ use crate::utils::bpf_utils::{bpf_attach_to_xdp, bpf_detach_from_xdp}; use crate::access_log::LogSenderConfig; use crate::worker::log::set_log_sender_config; +use crate::worker::agent_status::AgentStatusWorker; +use crate::agent_status::AgentStatusIdentity; use crate::authcheck::validate_api_key; use crate::http_client::init_global_client; use crate::waf::actions::captcha::{CaptchaConfig, CaptchaProvider, init_captcha_client, start_cache_cleanup_task}; @@ -222,6 +227,7 @@ fn main() -> Result<()> { #[allow(clippy::too_many_lines)] async fn async_main(args: Args, config: Config) -> Result<()> { + let started_at = Utc::now(); if config.daemon.enabled { log::info!("Running in daemon mode (PID file: {})", config.daemon.pid_file); @@ -854,6 +860,124 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } } + // Register agent status worker (register + heartbeat) if unified event sending is enabled + if log_sender_enabled { + let hostname = std::env::var("HOSTNAME") + .ok() + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned()); + + let agent_id = std::env::var("AGENT_ID") + .ok() + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| hostname.clone()); + + let agent_name = std::env::var("AGENT_NAME") + .ok() + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| hostname.clone()); + + let tags = std::env::var("AGENT_TAGS") + .ok() + .map(|value| { + value + .split(',') + .map(|tag| tag.trim().to_string()) + .filter(|tag| !tag.is_empty()) + .collect::>() + }) + .unwrap_or_default(); + + let mut capabilities = Vec::new(); + if log_sender_enabled { + capabilities.push("log_sender".to_string()); 
+ } + if config.bpf_stats.enabled { + capabilities.push("bpf_stats".to_string()); + } + if config.bpf_stats.enable_dropped_ip_events { + capabilities.push("bpf_stats_dropped_ip_events".to_string()); + } + if config.tcp_fingerprint.enabled { + capabilities.push("tcp_fingerprint".to_string()); + } + if config.tcp_fingerprint.enable_fingerprint_events { + capabilities.push("tcp_fingerprint_events".to_string()); + } + if content_scanner_enabled { + capabilities.push("content_scanner".to_string()); + } + if waf_enabled { + capabilities.push("waf".to_string()); + } + if threat_client_enabled { + capabilities.push("threat_client".to_string()); + } + if captcha_client_enabled { + capabilities.push("captcha_client".to_string()); + } + if !config.network.disable_xdp { + capabilities.push("xdp".to_string()); + } + + let interfaces = if !config.network.ifaces.is_empty() { + config.network.ifaces.clone() + } else if !config.network.iface.is_empty() { + vec![config.network.iface.clone()] + } else { + Vec::new() + }; + + let ip_addresses = std::env::var("AGENT_IPS") + .or_else(|_| std::env::var("AGENT_IP_ADDRESSES")) + .ok() + .map(|value| { + value + .split(',') + .map(|ip| ip.trim().to_string()) + .filter(|ip| !ip.is_empty()) + .collect::<Vec<String>>() + }) + .unwrap_or_default(); + + let mut metadata = HashMap::new(); + metadata.insert("os".to_string(), std::env::consts::OS.to_string()); + metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string()); + metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string()); + metadata.insert("mode".to_string(), config.mode.clone()); + metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone()); + + let identity = AgentStatusIdentity { + agent_id, + agent_name, + hostname, + version: env!("CARGO_PKG_VERSION").to_string(), + mode: config.mode.clone(), + tags, + capabilities, + interfaces, + ip_addresses, + metadata, + started_at, + }; + + let heartbeat_secs = std::env::var("AGENT_HEARTBEAT_SECS") + .ok() + .and_then(|value| value.parse::<u64>().ok()) + .unwrap_or(30); + + let worker_config = worker::WorkerConfig { + name: "agent_status".to_string(), + interval_secs: heartbeat_secs, + enabled: true, + }; + + let agent_status_worker = AgentStatusWorker::new(identity, heartbeat_secs); + if let Err(e) = worker_manager.register_worker(worker_config, agent_status_worker) { + log::error!("Failed to register agent status worker: {}", e); + } + } + // Access rules were already initialized after XDP attachment above // Register config worker to fetch and apply configuration periodically diff --git a/src/utils/bpf_utils_noop.rs b/src/utils/bpf_utils_noop.rs new file mode 100644 index 0000000..ba4d863 --- /dev/null +++ b/src/utils/bpf_utils_noop.rs @@ -0,0 +1,16 @@ +use std::error::Error; + +use crate::bpf::FilterSkel; + +pub fn bpf_attach_to_xdp( + _skel: &mut FilterSkel<'_>, + _ifindex: i32, + _iface_name: Option<&str>, + _ip_version: &str, +) -> Result<(), Box<dyn Error>> { + Err("BPF support disabled at build time".into()) +} + +pub fn bpf_detach_from_xdp(_ifindex: i32) -> Result<(), Box<dyn Error>> { + Ok(()) +} diff --git a/src/utils/tcp_fingerprint_noop.rs b/src/utils/tcp_fingerprint_noop.rs new file mode 100644 index 0000000..5a84b34 --- /dev/null +++ b/src/utils/tcp_fingerprint_noop.rs @@ -0,0 +1,166 @@ +use std::net::IpAddr; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::bpf::FilterSkel; + +/// TCP fingerprinting configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct
TcpFingerprintConfig { + pub enabled: bool, + pub log_interval_secs: u64, + pub enable_fingerprint_events: bool, + pub fingerprint_events_interval_secs: u64, + pub min_packet_count: u32, + pub min_connection_duration_secs: u64, +} + +impl Default for TcpFingerprintConfig { + fn default() -> Self { + Self { + enabled: false, + log_interval_secs: 60, + enable_fingerprint_events: false, + fingerprint_events_interval_secs: 30, + min_packet_count: 3, + min_connection_duration_secs: 1, + } + } +} + +impl TcpFingerprintConfig { + pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self { + Self { + enabled: cli_config.enabled, + log_interval_secs: cli_config.log_interval_secs, + enable_fingerprint_events: cli_config.enable_fingerprint_events, + fingerprint_events_interval_secs: cli_config.fingerprint_events_interval_secs, + min_packet_count: cli_config.min_packet_count, + min_connection_duration_secs: cli_config.min_connection_duration_secs, + } + } +} + +/// TCP fingerprint data collected from BPF +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintData { + pub first_seen: DateTime<Utc>, + pub last_seen: DateTime<Utc>, + pub packet_count: u32, + pub ttl: u16, + pub mss: u16, + pub window_size: u16, + pub window_scale: u8, + pub options_len: u8, + pub options: Vec<u8>, +} + +/// TCP fingerprint event for API +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintEvent { + pub event_type: String, + pub timestamp: DateTime<Utc>, + pub src_ip: String, + pub src_port: u16, + pub fingerprint: String, + pub ttl: u16, + pub mss: u16, + pub window_size: u16, + pub window_scale: u8, + pub packet_count: u32, +} + +impl TcpFingerprintEvent { + pub fn summary(&self) -> String { + format!( + "TCP fingerprint event: {}:{} {}", + self.src_ip, self.src_port, self.fingerprint + ) + } +} + +/// Collection of TCP fingerprint events +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintEvents { + pub events: Vec<TcpFingerprintEvent>, + pub total_events: u64, + pub unique_ips: u64, +} + +impl TcpFingerprintEvents { + pub fn new() -> Self { + Self { + events: Vec::new(), + total_events: 0, + unique_ips: 0, + } + } + + pub fn summary(&self) -> String { + format!( + "TCP fingerprint events: {} events from {} unique IPs", + self.total_events, + self.unique_ips + ) + } + + pub fn get_top_fingerprints(&self, limit: usize) -> Vec<TcpFingerprintEvent> { + let mut events = self.events.clone(); + events.sort_by(|a, b| b.packet_count.cmp(&a.packet_count)); + events.into_iter().take(limit).collect() + } +} + +static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock<Arc<TcpFingerprintCollector>> = std::sync::OnceLock::new(); + +pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) { + let _ = TCP_FINGERPRINT_COLLECTOR.set(Arc::new(collector)); +} + +pub fn get_global_tcp_fingerprint_collector() -> Option<Arc<TcpFingerprintCollector>> { + TCP_FINGERPRINT_COLLECTOR.get().cloned() +} + +#[derive(Clone)] +pub struct TcpFingerprintCollector { + enabled: bool, + _skels: Vec<Arc<FilterSkel<'static>>>, +} + +impl TcpFingerprintCollector { + pub fn new(skels: Vec<Arc<FilterSkel<'static>>>, enabled: bool) -> Self { + Self { enabled, _skels: skels } + } + + pub fn new_with_config(skels: Vec<Arc<FilterSkel<'static>>>, config: TcpFingerprintConfig) -> Self { + Self { + enabled: config.enabled, + _skels: skels, + } + } + + pub fn lookup_fingerprint(&self, _src_ip: IpAddr, _src_port: u16) -> Option<TcpFingerprintData> { + None + } + + pub fn collect_fingerprint_events(&self) -> Result<TcpFingerprintEvents, Box<dyn std::error::Error>> { + Ok(TcpFingerprintEvents::new()) + } + + pub fn log_stats(&self) -> Result<(), Box<dyn std::error::Error>> { + if self.enabled { + log::debug!("TCP fingerprint stats disabled (BPF
support not built)"); + } + Ok(()) + } + + pub fn log_fingerprint_events(&self) -> Result<(), Box> { + Ok(()) + } + + pub fn log_events(&self) -> Result<(), Box> { + Ok(()) + } +} diff --git a/src/worker/agent_status.rs b/src/worker/agent_status.rs new file mode 100644 index 0000000..19b209b --- /dev/null +++ b/src/worker/agent_status.rs @@ -0,0 +1,57 @@ +use std::time::Duration; + +use tokio::sync::watch; +use tokio::time::interval; + +use crate::agent_status::AgentStatusIdentity; +use crate::worker::log::{send_event, UnifiedEvent}; + +/// Agent status worker that sends register + heartbeat events +pub struct AgentStatusWorker { + identity: AgentStatusIdentity, + interval_secs: u64, +} + +impl AgentStatusWorker { + pub fn new(identity: AgentStatusIdentity, interval_secs: u64) -> Self { + Self { + identity, + interval_secs, + } + } +} + +impl super::Worker for AgentStatusWorker { + fn name(&self) -> &str { + "agent_status" + } + + fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { + let identity = self.identity.clone(); + let interval_secs = self.interval_secs; + let worker_name = self.name().to_string(); + + tokio::spawn(async move { + let mut tick = interval(Duration::from_secs(interval_secs)); + tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + // Initial register event + send_event(UnifiedEvent::AgentStatus(identity.to_event("running"))); + + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + log::info!("[{}] Shutdown signal received, stopping agent status worker", worker_name); + break; + } + } + _ = tick.tick() => { + send_event(UnifiedEvent::AgentStatus(identity.to_event("running"))); + } + } + } + }) + } +} + diff --git a/src/worker/log.rs b/src/worker/log.rs index 0aaf476..d30cbed 100644 --- a/src/worker/log.rs +++ b/src/worker/log.rs @@ -74,6 +74,8 @@ pub enum UnifiedEvent { DroppedIp(crate::bpf_stats::DroppedIpEvent), #[serde(rename = "tcp_fingerprint")] TcpFingerprint(crate::utils::tcp_fingerprint::TcpFingerprintEvent), + #[serde(rename = "agent_status")] + AgentStatus(crate::agent_status::AgentStatusEvent), } impl UnifiedEvent { @@ -83,6 +85,7 @@ impl UnifiedEvent { UnifiedEvent::HttpAccessLog(_) => "http_access_log", UnifiedEvent::DroppedIp(_) => "dropped_ip", UnifiedEvent::TcpFingerprint(_) => "tcp_fingerprint", + UnifiedEvent::AgentStatus(_) => "agent_status", } } @@ -92,6 +95,7 @@ impl UnifiedEvent { UnifiedEvent::HttpAccessLog(event) => event.timestamp, UnifiedEvent::DroppedIp(event) => event.timestamp, UnifiedEvent::TcpFingerprint(event) => event.timestamp, + UnifiedEvent::AgentStatus(event) => event.timestamp, } } @@ -204,6 +208,7 @@ fn estimate_event_size(event: &UnifiedEvent) -> usize { } UnifiedEvent::DroppedIp(_) => base_size + 200, // Dropped IP events are relatively small UnifiedEvent::TcpFingerprint(_) => base_size + 100, // TCP fingerprint events are small + UnifiedEvent::AgentStatus(_) => base_size + 300, // Agent status events are small } } diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 85def86..def7bc9 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -6,3 +6,4 @@ pub mod manager; pub mod threat_mmdb; pub use manager::{Worker, WorkerConfig, WorkerManager}; +pub mod agent_status; From 3555fe0ef36f769081708f76026a9afa0b3ff002 Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:55:38 +0100 Subject: [PATCH 02/14] cleanup --- pkg/debug/.dockerignore | 10 - pkg/debug/build-legacy.Dockerfile | 62 -- 
pkg/debug/build.Dockerfile | 59 -- pkg/debug/diagnostic.sh | 1038 ----------------------------- pkg/debug/local_build.sh | 22 - 5 files changed, 1191 deletions(-) delete mode 100644 pkg/debug/.dockerignore delete mode 100644 pkg/debug/build-legacy.Dockerfile delete mode 100644 pkg/debug/build.Dockerfile delete mode 100644 pkg/debug/diagnostic.sh delete mode 100644 pkg/debug/local_build.sh diff --git a/pkg/debug/.dockerignore b/pkg/debug/.dockerignore deleted file mode 100644 index dd3f5fa..0000000 --- a/pkg/debug/.dockerignore +++ /dev/null @@ -1,10 +0,0 @@ -target -.git -.vscode -helm -docs -.github -.devcontainer -images -*.md -docker/volumes diff --git a/pkg/debug/build-legacy.Dockerfile b/pkg/debug/build-legacy.Dockerfile deleted file mode 100644 index ce004f4..0000000 --- a/pkg/debug/build-legacy.Dockerfile +++ /dev/null @@ -1,62 +0,0 @@ -ARG IMAGE="ubuntu" -ARG IMAGE_TAG="16.04" -ARG BUILD_FLAGS - -FROM ${IMAGE}:${IMAGE_TAG} - -RUN sed -i '/updates/d' /etc/apt/sources.list && \ - sed -i 's/httpredir/archive/' /etc/apt/sources.list && \ - sed -i 's|https://|http://|g' /etc/apt/sources.list - -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg \ - lsb-release && \ - curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ - echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ - apt-get install -y --no-install-recommends \ - libc6 \ - libc6-dev \ - g++ \ - gcc \ - make && \ - apt-get install -y --no-install-recommends \ - git \ - build-essential \ - clang-10 \ - llvm-10 \ - libelf-dev \ - libelf1 \ - libssl-dev \ - zlib1g-dev \ - libzstd-dev \ - pkg-config \ - libcap-dev \ - binutils-multiarch-dev \ - cmake && \ - update-alternatives --install /usr/bin/clang clang /usr/bin/clang-10 100 && \ - update-alternatives --install /usr/bin/llc llc /usr/bin/llc-10 100 && \ - rm -rf /var/lib/apt/lists/* - -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ - . "$HOME/.cargo/env" && \ - rustup default stable && \ - rustup update stable -ENV PATH="/root/.cargo/bin:${PATH}" - -WORKDIR /app - -COPY . . - -# If BUILD_FLAGS is unset or empty, default to --no-default-features (no eBPF). 
-RUN cargo build --release ${BUILD_FLAGS:---no-default-features} - -# Create output directory and copy binary -RUN mkdir -p /output && \ - cp target/release/synapse /output/synapse - -VOLUME ["/output"] diff --git a/pkg/debug/build.Dockerfile b/pkg/debug/build.Dockerfile deleted file mode 100644 index 23d8249..0000000 --- a/pkg/debug/build.Dockerfile +++ /dev/null @@ -1,59 +0,0 @@ -ARG IMAGE="ubuntu" -ARG IMAGE_TAG="18.04" -ARG BUILD_FLAGS="" - -FROM ${IMAGE}:${IMAGE_TAG} - -RUN sed -i '/updates/d' /etc/apt/sources.list && \ - sed -i 's/httpredir/archive/' /etc/apt/sources.list - -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - ca-certificates \ - curl \ - gnupg \ - lsb-release && \ - curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ - echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ - apt-get install -y --no-install-recommends --allow-downgrades \ - libc6=2.27-3ubuntu1.5 \ - libc6-dev \ - g++ \ - gcc \ - make && \ - apt-get install -y --no-install-recommends \ - git \ - build-essential \ - clang-10 \ - llvm-10 \ - libelf-dev \ - libelf1 \ - libssl-dev \ - zlib1g-dev \ - libzstd-dev \ - pkg-config \ - libcap-dev \ - binutils-multiarch-dev \ - cmake && \ - update-alternatives --install /usr/bin/clang clang /usr/bin/clang-10 100 && \ - update-alternatives --install /usr/bin/llc llc /usr/bin/llc-10 100 && \ - rm -rf /var/lib/apt/lists/* - -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ - . "$HOME/.cargo/env" && \ - rustup default stable && \ - rustup update stable -ENV PATH="/root/.cargo/bin:${PATH}" - -WORKDIR /app - -COPY . . - -RUN cargo build --release $BUILD_FLAGS - -# Create output directory and copy binary -RUN mkdir -p /output && \ - cp target/release/synapse /output/synapse - -VOLUME ["/output"] diff --git a/pkg/debug/diagnostic.sh b/pkg/debug/diagnostic.sh deleted file mode 100644 index 1d3b7ff..0000000 --- a/pkg/debug/diagnostic.sh +++ /dev/null @@ -1,1038 +0,0 @@ -#!/bin/bash -# -# Synapse Diagnostic Script -# Checks system capabilities for running moat with all its features -# - -set -e - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -CYAN='\033[0;36m' -NC='\033[0m' # No Color -BOLD='\033[1m' - -# Feature tracking -declare -A FEATURES -FEATURES["xdp"]=0 -FEATURES["xdp_hardware"]=0 -FEATURES["xdp_driver"]=0 -FEATURES["xdp_skb"]=0 -FEATURES["nftables"]=0 -FEATURES["iptables"]=0 -FEATURES["ip6tables"]=0 -FEATURES["bpf_syscall"]=0 -FEATURES["bpf_jit"]=0 -FEATURES["btf"]=0 -FEATURES["ipv6"]=0 - -# Minimum requirements -MIN_KERNEL_MAJOR=4 -MIN_KERNEL_MINOR=18 -MIN_GLIBC_MAJOR=2 -MIN_GLIBC_MINOR=17 - -print_header() { - echo "" - echo -e "${BOLD}${BLUE}========================================${NC}" - echo -e "${BOLD}${BLUE} $1${NC}" - echo -e "${BOLD}${BLUE}========================================${NC}" -} - -print_section() { - echo "" - echo -e "${CYAN}--- $1 ---${NC}" -} - -print_ok() { - echo -e " ${GREEN}[OK]${NC} $1" -} - -print_warn() { - echo -e " ${YELLOW}[WARN]${NC} $1" -} - -print_fail() { - echo -e " ${RED}[FAIL]${NC} $1" -} - -print_info() { - echo -e " ${BLUE}[INFO]${NC} $1" -} - -version_compare() { - # Returns 0 if $1 >= $2 - local v1_major v1_minor v2_major v2_minor - v1_major=$(echo "$1" | cut -d. -f1) - v1_minor=$(echo "$1" | cut -d. 
-f2) - v2_major=$3 - v2_minor=$4 - - if [ "$v1_major" -gt "$v2_major" ]; then - return 0 - elif [ "$v1_major" -eq "$v2_major" ] && [ "$v1_minor" -ge "$v2_minor" ]; then - return 0 - fi - return 1 -} - -print_header "Synapse System Diagnostic" -echo -e "Running diagnostics at: $(date)" -echo -e "Hostname: $(hostname)" - -# ============================================================================= -# LINUX DISTRIBUTION CHECK -# ============================================================================= -print_section "Linux Distribution" - -DISTRO_NAME="Unknown" -DISTRO_VERSION="Unknown" -DISTRO_ID="unknown" - -# Try /etc/os-release first (most modern distributions) -if [ -f /etc/os-release ]; then - . /etc/os-release - DISTRO_NAME="${NAME:-Unknown}" - DISTRO_VERSION="${VERSION:-${VERSION_ID:-Unknown}}" - DISTRO_ID="${ID:-unknown}" - DISTRO_ID_LIKE="${ID_LIKE:-}" - DISTRO_PRETTY="${PRETTY_NAME:-$DISTRO_NAME $DISTRO_VERSION}" - - print_info "Distribution: $DISTRO_PRETTY" - if [ -n "$DISTRO_ID_LIKE" ]; then - print_info "Based on: $DISTRO_ID_LIKE" - fi -# Try /etc/lsb-release (older Ubuntu/Debian) -elif [ -f /etc/lsb-release ]; then - . /etc/lsb-release - DISTRO_NAME="${DISTRIB_ID:-Unknown}" - DISTRO_VERSION="${DISTRIB_RELEASE:-Unknown}" - DISTRO_ID=$(echo "$DISTRO_NAME" | tr '[:upper:]' '[:lower:]') - - print_info "Distribution: $DISTRO_NAME $DISTRO_VERSION" - if [ -n "${DISTRIB_CODENAME:-}" ]; then - print_info "Codename: $DISTRIB_CODENAME" - fi -# Try specific release files -elif [ -f /etc/redhat-release ]; then - DISTRO_PRETTY=$(cat /etc/redhat-release) - DISTRO_ID="rhel" - print_info "Distribution: $DISTRO_PRETTY" -elif [ -f /etc/debian_version ]; then - DISTRO_VERSION=$(cat /etc/debian_version) - DISTRO_NAME="Debian" - DISTRO_ID="debian" - print_info "Distribution: Debian $DISTRO_VERSION" -elif [ -f /etc/alpine-release ]; then - DISTRO_VERSION=$(cat /etc/alpine-release) - DISTRO_NAME="Alpine Linux" - DISTRO_ID="alpine" - print_info "Distribution: Alpine Linux $DISTRO_VERSION" -elif [ -f /etc/arch-release ]; then - DISTRO_NAME="Arch Linux" - DISTRO_ID="arch" - print_info "Distribution: Arch Linux (rolling release)" -elif [ -f /etc/gentoo-release ]; then - DISTRO_PRETTY=$(cat /etc/gentoo-release) - DISTRO_ID="gentoo" - print_info "Distribution: $DISTRO_PRETTY" -elif [ -f /etc/SuSE-release ]; then - DISTRO_PRETTY=$(head -1 /etc/SuSE-release) - DISTRO_ID="suse" - print_info "Distribution: $DISTRO_PRETTY" -else - print_warn "Could not detect Linux distribution" -fi - -# Architecture -ARCH=$(uname -m) -print_info "Architecture: $ARCH" - -# Get kernel version early (needed for distribution-specific checks) -KERNEL_VERSION=$(uname -r) -KERNEL_MAJOR=$(echo "$KERNEL_VERSION" | cut -d. -f1) -KERNEL_MINOR=$(echo "$KERNEL_VERSION" | cut -d. 
-f2) - -# ============================================================================= -# HARDWARE / VIRTUALIZATION CHECK -# ============================================================================= -print_section "Hardware / Virtualization" - -IS_PHYSICAL=true -IS_CONTAINER=false -IS_VM=false -VIRT_TYPE="" -CONTAINER_TYPE="" -HYPERVISOR="" - -# Check if running in container first -if [ -f /.dockerenv ]; then - IS_CONTAINER=true - CONTAINER_TYPE="Docker" -elif [ -f /run/.containerenv ]; then - IS_CONTAINER=true - CONTAINER_TYPE="Podman" -elif grep -q "container=lxc" /proc/1/environ 2>/dev/null; then - IS_CONTAINER=true - CONTAINER_TYPE="LXC" -elif grep -q "/docker/" /proc/1/cgroup 2>/dev/null; then - IS_CONTAINER=true - CONTAINER_TYPE="Docker" -elif grep -q "/lxc/" /proc/1/cgroup 2>/dev/null; then - IS_CONTAINER=true - CONTAINER_TYPE="LXC" -elif grep -q "/kubepods/" /proc/1/cgroup 2>/dev/null; then - IS_CONTAINER=true - CONTAINER_TYPE="Kubernetes Pod" -elif [ -d /run/systemd/system ] && systemd-detect-virt --container &>/dev/null; then - DETECTED=$(systemd-detect-virt --container 2>/dev/null) - if [ "$DETECTED" != "none" ] && [ -n "$DETECTED" ]; then - IS_CONTAINER=true - CONTAINER_TYPE="$DETECTED" - fi -fi - -# Detect virtualization technology using multiple methods -detect_virtualization() { - # Method 1: systemd-detect-virt (most reliable on systemd systems) - if command -v systemd-detect-virt &> /dev/null; then - DETECTED=$(systemd-detect-virt --vm 2>/dev/null) || true - if [ -n "$DETECTED" ] && [ "$DETECTED" != "none" ]; then - IS_VM=true - IS_PHYSICAL=false - VIRT_TYPE="$DETECTED" - return 0 - fi - fi - - # Method 2: Check DMI/SMBIOS information (may not exist on ARM) - if [ -f /sys/class/dmi/id/product_name ]; then - PRODUCT_NAME=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true - case "$PRODUCT_NAME" in - *VirtualBox*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VirtualBox"; return 0 ;; - *VMware*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VMware"; return 0 ;; - *Virtual\ Machine*|*Hyper-V*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Hyper-V"; return 0 ;; - *KVM*|*QEMU*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="KVM/QEMU"; return 0 ;; - *Bochs*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Bochs"; return 0 ;; - *Parallels*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Parallels"; return 0 ;; - esac - fi - - # Method 3: Check sys_vendor - if [ -f /sys/class/dmi/id/sys_vendor ]; then - SYS_VENDOR=$(cat /sys/class/dmi/id/sys_vendor 2>/dev/null) || true - case "$SYS_VENDOR" in - *VMware*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VMware"; return 0 ;; - *innotek*|*Oracle*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="VirtualBox"; return 0 ;; - *Xen*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Xen"; return 0 ;; - *Microsoft*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Hyper-V"; return 0 ;; - *QEMU*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="QEMU"; return 0 ;; - *Amazon\ EC2*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Amazon EC2 (Xen/Nitro)"; return 0 ;; - *Google*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Google Cloud"; return 0 ;; - *DigitalOcean*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="DigitalOcean"; return 0 ;; - esac - fi - - # Method 4: Check board_vendor - if [ -f /sys/class/dmi/id/board_vendor ]; then - BOARD_VENDOR=$(cat /sys/class/dmi/id/board_vendor 2>/dev/null) || true - case "$BOARD_VENDOR" in - *Amazon*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Amazon EC2"; return 0 ;; - esac - fi - - # Method 5: Check /proc/cpuinfo for hypervisor flag 
(x86/x86_64) - if grep -q "^flags.*hypervisor" /proc/cpuinfo 2>/dev/null; then - IS_VM=true - IS_PHYSICAL=false - # Try to identify which hypervisor - if grep -q "^flags.*vmx" /proc/cpuinfo 2>/dev/null; then - VIRT_TYPE="Unknown (nested virt with VMX)" - elif grep -q "^flags.*svm" /proc/cpuinfo 2>/dev/null; then - VIRT_TYPE="Unknown (nested virt with SVM)" - else - VIRT_TYPE="Unknown hypervisor" - fi - return 0 - fi - - # Method 6: Check for Xen via /sys/hypervisor - if [ -f /sys/hypervisor/type ]; then - HYPER_TYPE=$(cat /sys/hypervisor/type 2>/dev/null) || true - if [ -n "$HYPER_TYPE" ]; then - IS_VM=true - IS_PHYSICAL=false - VIRT_TYPE="$HYPER_TYPE" - return 0 - fi - fi - - # Method 7: Check dmesg for virtualization hints (requires root usually) - if command -v dmesg &> /dev/null; then - DMESG_OUTPUT=$(dmesg 2>/dev/null | grep -i "hypervisor detected" | head -1) || true - if [ -n "$DMESG_OUTPUT" ]; then - IS_VM=true - IS_PHYSICAL=false - VIRT_TYPE=$(echo "$DMESG_OUTPUT" | sed 's/.*: //') - return 0 - fi - fi - - # Method 8: Check for virt-what output if available - if command -v virt-what &> /dev/null; then - VIRT_WHAT=$(virt-what 2>/dev/null | head -1) || true - if [ -n "$VIRT_WHAT" ]; then - IS_VM=true - IS_PHYSICAL=false - VIRT_TYPE="$VIRT_WHAT" - return 0 - fi - fi - - # Method 9: ARM-specific checks - if [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "armv7l" ] || [ "$ARCH" = "armv8l" ]; then - # Check for device tree model (Raspberry Pi, etc.) - if [ -f /proc/device-tree/model ]; then - DT_MODEL=$(cat /proc/device-tree/model 2>/dev/null | tr -d '\0') || true - if [ -n "$DT_MODEL" ]; then - HW_MODEL="$DT_MODEL" - fi - fi - - # Check for ARM cloud VMs via device tree - if [ -f /proc/device-tree/hypervisor/compatible ]; then - HYPER_COMPAT=$(cat /proc/device-tree/hypervisor/compatible 2>/dev/null | tr -d '\0') || true - if [ -n "$HYPER_COMPAT" ]; then - IS_VM=true - IS_PHYSICAL=false - VIRT_TYPE="$HYPER_COMPAT" - return 0 - fi - fi - - # Check for common ARM cloud platforms - if [ -f /sys/firmware/devicetree/base/compatible ]; then - DT_COMPAT=$(cat /sys/firmware/devicetree/base/compatible 2>/dev/null | tr '\0' ' ') || true - case "$DT_COMPAT" in - *amazon,graviton*|*aws,nitro*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="AWS Graviton (Nitro)"; return 0 ;; - *google,*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Google Cloud ARM"; return 0 ;; - *azure,*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Azure ARM"; return 0 ;; - *oracle,*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Oracle Cloud ARM"; return 0 ;; - *ampere,*) - IS_VM=true; IS_PHYSICAL=false; VIRT_TYPE="Ampere Cloud"; return 0 ;; - esac - fi - fi - - return 1 -} - -# Run virtualization detection (don't exit on failure) -detect_virtualization || true - -# Get hypervisor info if available -if [ -f /sys/hypervisor/uuid ]; then - HYPERVISOR_UUID=$(cat /sys/hypervisor/uuid 2>/dev/null) || true -fi - -# Display results -if [ "$IS_CONTAINER" = true ]; then - print_info "Environment: Container ($CONTAINER_TYPE)" - if [ "$IS_VM" = true ]; then - print_info "Host appears to be: Virtual Machine ($VIRT_TYPE)" - fi -elif [ "$IS_VM" = true ]; then - print_info "Environment: Virtual Machine" - print_ok "Virtualization: $VIRT_TYPE" - - # Additional VM details - if [ -f /sys/class/dmi/id/product_name ]; then - PRODUCT=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true - [ -n "$PRODUCT" ] && print_info "Product: $PRODUCT" - fi - if [ -f /sys/class/dmi/id/product_version ]; then - VERSION=$(cat /sys/class/dmi/id/product_version 2>/dev/null) || 
true - [ -n "$VERSION" ] && [ "$VERSION" != "None" ] && print_info "Version: $VERSION" - fi - if [ -n "$HYPERVISOR_UUID" ]; then - print_info "Hypervisor UUID: $HYPERVISOR_UUID" - fi - - # Check for cloud provider - CLOUD_PROVIDER="" - if [ -f /sys/class/dmi/id/chassis_asset_tag ]; then - CHASSIS_TAG=$(cat /sys/class/dmi/id/chassis_asset_tag 2>/dev/null) || true - case "$CHASSIS_TAG" in - *Amazon*) CLOUD_PROVIDER="AWS" ;; - *Google*) CLOUD_PROVIDER="Google Cloud" ;; - *Azure*) CLOUD_PROVIDER="Microsoft Azure" ;; - esac - fi - if [ -z "$CLOUD_PROVIDER" ] && [ -f /sys/class/dmi/id/board_asset_tag ]; then - BOARD_TAG=$(cat /sys/class/dmi/id/board_asset_tag 2>/dev/null) || true - case "$BOARD_TAG" in - i-*) CLOUD_PROVIDER="AWS (EC2 instance: $BOARD_TAG)" ;; - esac - fi - [ -n "$CLOUD_PROVIDER" ] && print_info "Cloud Provider: $CLOUD_PROVIDER" -else - # Physical machine or unknown - print_ok "Environment: Physical Machine" - - # Show hardware info - HW_INFO_FOUND=false - - # ARM: Check device tree model first - if [ -n "${HW_MODEL:-}" ]; then - print_info "Hardware: $HW_MODEL" - HW_INFO_FOUND=true - elif [ -f /proc/device-tree/model ]; then - DT_MODEL=$(cat /proc/device-tree/model 2>/dev/null | tr -d '\0') || true - if [ -n "$DT_MODEL" ]; then - print_info "Hardware: $DT_MODEL" - HW_INFO_FOUND=true - fi - fi - - # x86: Check DMI info - if [ -f /sys/class/dmi/id/product_name ]; then - PRODUCT=$(cat /sys/class/dmi/id/product_name 2>/dev/null) || true - if [ -n "$PRODUCT" ] && [ "$PRODUCT" != "None" ]; then - print_info "Product: $PRODUCT" - HW_INFO_FOUND=true - fi - fi - if [ -f /sys/class/dmi/id/sys_vendor ]; then - VENDOR=$(cat /sys/class/dmi/id/sys_vendor 2>/dev/null) || true - if [ -n "$VENDOR" ] && [ "$VENDOR" != "None" ]; then - print_info "Vendor: $VENDOR" - HW_INFO_FOUND=true - fi - fi - - # If no hardware info found, show CPU info - if [ "$HW_INFO_FOUND" = false ]; then - if [ -f /proc/cpuinfo ]; then - # ARM: look for Hardware or model name - CPU_MODEL=$(grep -m1 "^model name\|^Hardware\|^Model" /proc/cpuinfo 2>/dev/null | cut -d: -f2 | sed 's/^ //') || true - if [ -n "$CPU_MODEL" ]; then - print_info "CPU: $CPU_MODEL" - fi - fi - fi -fi - -# Check CPU virtualization support -if [ -f /proc/cpuinfo ]; then - # x86/x86_64: Check for VMX/SVM - if grep -q "^flags.*vmx" /proc/cpuinfo 2>/dev/null; then - print_info "CPU Virtualization: Intel VT-x supported" - elif grep -q "^flags.*svm" /proc/cpuinfo 2>/dev/null; then - print_info "CPU Virtualization: AMD-V supported" - # ARM: Check for virtualization extensions - elif [ "$ARCH" = "aarch64" ]; then - # Check CPU features for ARM virtualization - if grep -q "^Features.*:.*" /proc/cpuinfo 2>/dev/null; then - ARM_FEATURES=$(grep -m1 "^Features" /proc/cpuinfo | cut -d: -f2) || true - # ARM VHE (Virtualization Host Extensions) present in ARMv8.1+ - print_info "CPU: ARM64 (hardware virtualization capable)" - fi - fi -fi - -# Note about XDP in VMs -if [ "$IS_VM" = true ]; then - print_warn "Running in VM - XDP hardware offload not available" - print_info " XDP will use driver or SKB mode (still performant)" -fi - -# Distribution-specific notes for moat -case "$DISTRO_ID" in - alpine) - print_warn "Alpine Linux detected - ensure libc6-compat is installed for glibc binaries" - ;; - ubuntu|debian) - if [ "$KERNEL_MAJOR" -lt 5 ]; then - print_info "Consider upgrading to a newer kernel for better eBPF support" - fi - ;; - rhel|centos|rocky|almalinux) - print_info "RHEL-based distribution - ensure kernel-headers and bpftool packages are available" - ;; 
-esac - -# ============================================================================= -# KERNEL VERSION CHECK -# ============================================================================= -print_section "Kernel Version" - -KERNEL_VERSION=$(uname -r) -KERNEL_MAJOR=$(echo "$KERNEL_VERSION" | cut -d. -f1) -KERNEL_MINOR=$(echo "$KERNEL_VERSION" | cut -d. -f2) - -print_info "Kernel version: $KERNEL_VERSION" - -if version_compare "$KERNEL_MAJOR.$KERNEL_MINOR" "" $MIN_KERNEL_MAJOR $MIN_KERNEL_MINOR; then - print_ok "Kernel version >= ${MIN_KERNEL_MAJOR}.${MIN_KERNEL_MINOR} (required for XDP)" - FEATURES["xdp"]=1 -else - print_fail "Kernel version < ${MIN_KERNEL_MAJOR}.${MIN_KERNEL_MINOR} (XDP requires >= 4.18)" -fi - -# Check for specific kernel features via /boot/config if available -if [ -f "/boot/config-$KERNEL_VERSION" ]; then - print_info "Found kernel config at /boot/config-$KERNEL_VERSION" -elif [ -f "/proc/config.gz" ]; then - print_info "Found kernel config at /proc/config.gz" -fi - -# ============================================================================= -# GLIBC VERSION CHECK -# ============================================================================= -print_section "glibc Version" - -if command -v ldd &> /dev/null; then - GLIBC_VERSION=$(ldd --version 2>&1 | head -n1 | grep -oE '[0-9]+\.[0-9]+' | head -1) - if [ -n "$GLIBC_VERSION" ]; then - print_info "glibc version: $GLIBC_VERSION" - if version_compare "$GLIBC_VERSION" "" $MIN_GLIBC_MAJOR $MIN_GLIBC_MINOR; then - print_ok "glibc version >= ${MIN_GLIBC_MAJOR}.${MIN_GLIBC_MINOR}" - else - print_warn "glibc version < ${MIN_GLIBC_MAJOR}.${MIN_GLIBC_MINOR}" - fi - else - print_warn "Could not determine glibc version" - fi -else - print_warn "ldd not available, cannot check glibc version" -fi - -# Check musl as alternative -if command -v musl-ldd &> /dev/null || [ -f /lib/ld-musl-*.so.1 ]; then - MUSL_VERSION=$(ls /lib/ld-musl-*.so.1 2>/dev/null | head -1) - if [ -n "$MUSL_VERSION" ]; then - print_info "musl libc detected: $MUSL_VERSION" - fi -fi - -# ============================================================================= -# BPF/EBPF SUPPORT CHECK (without bpftool) -# ============================================================================= -print_section "eBPF Support" - -# Check if BPF syscall is available via /proc/kallsyms -if [ -f /proc/kallsyms ]; then - if grep -q "bpf_prog_load" /proc/kallsyms 2>/dev/null || grep -q " bpf_" /proc/kallsyms 2>/dev/null; then - print_ok "BPF syscall symbols found in kernel" - FEATURES["bpf_syscall"]=1 - else - # Try alternative check - /sys/fs/bpf - if [ -d /sys/fs/bpf ]; then - print_ok "BPF filesystem available at /sys/fs/bpf" - FEATURES["bpf_syscall"]=1 - else - print_fail "BPF syscall not available" - fi - fi -else - # Fallback: check /sys/fs/bpf - if [ -d /sys/fs/bpf ]; then - print_ok "BPF filesystem available at /sys/fs/bpf" - FEATURES["bpf_syscall"]=1 - else - print_warn "Cannot verify BPF support (/proc/kallsyms not readable)" - fi -fi - -# Check for BPF JIT -if [ -f /proc/sys/net/core/bpf_jit_enable ]; then - BPF_JIT=$(cat /proc/sys/net/core/bpf_jit_enable 2>/dev/null) - if [ "$BPF_JIT" = "1" ] || [ "$BPF_JIT" = "2" ]; then - print_ok "BPF JIT is enabled (value: $BPF_JIT)" - FEATURES["bpf_jit"]=1 - else - print_warn "BPF JIT is disabled (value: $BPF_JIT)" - print_info " Enable with: echo 1 > /proc/sys/net/core/bpf_jit_enable" - fi -else - print_warn "BPF JIT sysctl not found" -fi - -# Check for BTF (BPF Type Format) support -if [ -f /sys/kernel/btf/vmlinux ]; then 
- print_ok "BTF (BPF Type Format) is available" - FEATURES["btf"]=1 -else - print_warn "BTF not available (CO-RE features may be limited)" - print_info " BTF requires CONFIG_DEBUG_INFO_BTF=y in kernel config" -fi - -# Check BPF hardening settings -if [ -f /proc/sys/kernel/unprivileged_bpf_disabled ]; then - UNPRIV_BPF=$(cat /proc/sys/kernel/unprivileged_bpf_disabled 2>/dev/null) - case "$UNPRIV_BPF" in - 0) - print_info "Unprivileged BPF: allowed" - ;; - 1) - print_info "Unprivileged BPF: disabled (root required)" - ;; - 2) - print_info "Unprivileged BPF: permanently disabled" - ;; - *) - print_info "Unprivileged BPF setting: $UNPRIV_BPF" - ;; - esac -fi - -# Check BPF JIT hardening -if [ -f /proc/sys/net/core/bpf_jit_harden ]; then - BPF_HARDEN=$(cat /proc/sys/net/core/bpf_jit_harden 2>/dev/null) - case "$BPF_HARDEN" in - 0) - print_info "BPF JIT hardening: disabled" - ;; - 1) - print_info "BPF JIT hardening: enabled for unprivileged" - ;; - 2) - print_info "BPF JIT hardening: enabled for all" - ;; - esac -fi - -# Check memlock limit (critical for BPF maps) -print_section "Memory Limits" - -MEMLOCK_SUFFICIENT=false -MIN_MEMLOCK_KB=65536 # 64MB minimum recommended for BPF - -# Method 1: Check ulimit -if command -v ulimit &> /dev/null; then - MEMLOCK_SOFT=$(ulimit -l 2>/dev/null) || true - MEMLOCK_HARD=$(ulimit -Hl 2>/dev/null) || true - - if [ "$MEMLOCK_SOFT" = "unlimited" ]; then - print_ok "Memlock (soft limit): unlimited" - MEMLOCK_SUFFICIENT=true - elif [ -n "$MEMLOCK_SOFT" ] && [ "$MEMLOCK_SOFT" -ge "$MIN_MEMLOCK_KB" ] 2>/dev/null; then - print_ok "Memlock (soft limit): ${MEMLOCK_SOFT} KB (sufficient)" - MEMLOCK_SUFFICIENT=true - elif [ -n "$MEMLOCK_SOFT" ]; then - print_warn "Memlock (soft limit): ${MEMLOCK_SOFT} KB (may be too low for BPF)" - print_info " Recommended: >= ${MIN_MEMLOCK_KB} KB or unlimited" - fi - - if [ "$MEMLOCK_HARD" = "unlimited" ]; then - print_info "Memlock (hard limit): unlimited" - elif [ -n "$MEMLOCK_HARD" ]; then - print_info "Memlock (hard limit): ${MEMLOCK_HARD} KB" - fi -fi - -# Method 2: Check /proc/self/limits for more detail -if [ -f /proc/self/limits ]; then - LIMITS_LINE=$(grep "Max locked memory" /proc/self/limits 2>/dev/null) || true - if [ -n "$LIMITS_LINE" ]; then - SOFT_BYTES=$(echo "$LIMITS_LINE" | awk '{print $4}') - HARD_BYTES=$(echo "$LIMITS_LINE" | awk '{print $5}') - - if [ "$SOFT_BYTES" = "unlimited" ]; then - print_info "Process memlock: unlimited" - MEMLOCK_SUFFICIENT=true - elif [ -n "$SOFT_BYTES" ] && [ "$SOFT_BYTES" != "unlimited" ]; then - SOFT_MB=$((SOFT_BYTES / 1024 / 1024)) - print_info "Process memlock: ${SOFT_MB} MB (${SOFT_BYTES} bytes)" - fi - fi -fi - -# Method 3: Check systemd default if applicable -if [ -f /etc/systemd/system.conf ]; then - SYSTEMD_MEMLOCK=$(grep -E "^DefaultLimitMEMLOCK=" /etc/systemd/system.conf 2>/dev/null | cut -d= -f2) || true - if [ -n "$SYSTEMD_MEMLOCK" ]; then - print_info "Systemd DefaultLimitMEMLOCK: $SYSTEMD_MEMLOCK" - fi -fi - -# Provide fix instructions if memlock is insufficient -if [ "$MEMLOCK_SUFFICIENT" = false ]; then - print_warn "Memlock limit may be insufficient for BPF operations" - print_info " To fix temporarily: ulimit -l unlimited" - print_info " To fix permanently, add to /etc/security/limits.conf:" - print_info " * soft memlock unlimited" - print_info " * hard memlock unlimited" - print_info " Or for systemd services, add to unit file:" - print_info " LimitMEMLOCK=infinity" -fi - -# ============================================================================= -# KERNEL BPF 
CONFIG PARAMETERS -# ============================================================================= -print_section "Kernel BPF Configuration" - -check_kernel_config() { - local config_name="$1" - local description="$2" - local config_file="" - - # Find kernel config - if [ -f "/boot/config-$KERNEL_VERSION" ]; then - config_file="/boot/config-$KERNEL_VERSION" - elif [ -f "/proc/config.gz" ]; then - if command -v zcat &> /dev/null; then - if zcat /proc/config.gz 2>/dev/null | grep -q "^${config_name}="; then - print_ok "$config_name ($description)" - return 0 - elif zcat /proc/config.gz 2>/dev/null | grep -q "^# ${config_name} is not set"; then - print_fail "$config_name ($description) - not enabled" - return 1 - fi - fi - return 2 - fi - - if [ -n "$config_file" ]; then - if grep -q "^${config_name}=y" "$config_file" 2>/dev/null; then - print_ok "$config_name ($description)" - return 0 - elif grep -q "^${config_name}=m" "$config_file" 2>/dev/null; then - print_ok "$config_name ($description) [module]" - return 0 - elif grep -q "^# ${config_name} is not set" "$config_file" 2>/dev/null; then - print_fail "$config_name ($description) - not enabled" - return 1 - fi - fi - - print_info "$config_name ($description) - cannot verify" - return 2 -} - -# Required BPF configs -check_kernel_config "CONFIG_BPF" "Basic BPF support" -check_kernel_config "CONFIG_BPF_SYSCALL" "BPF system calls" -check_kernel_config "CONFIG_BPF_JIT" "BPF JIT compiler" -check_kernel_config "CONFIG_HAVE_EBPF_JIT" "eBPF JIT support" - -# XDP specific configs -check_kernel_config "CONFIG_XDP_SOCKETS" "XDP socket support" -check_kernel_config "CONFIG_NET_CLS_BPF" "BPF classifier" -check_kernel_config "CONFIG_NET_ACT_BPF" "BPF action module" - -# BTF and debugging -check_kernel_config "CONFIG_DEBUG_INFO_BTF" "BTF debug info" -check_kernel_config "CONFIG_BPF_LSM" "BPF LSM support" - -# Additional useful configs -check_kernel_config "CONFIG_CGROUP_BPF" "Cgroup BPF support" -check_kernel_config "CONFIG_BPF_STREAM_PARSER" "BPF stream parser" - -# ============================================================================= -# XDP SUPPORT CHECK -# ============================================================================= -print_section "XDP Support" - -# Check for XDP support in network drivers -if [ -d /sys/class/net ]; then - print_info "Available network interfaces:" - for iface in /sys/class/net/*; do - iface_name=$(basename "$iface") - if [ "$iface_name" != "lo" ]; then - driver="" - if [ -L "$iface/device/driver" ]; then - driver=$(basename "$(readlink "$iface/device/driver")") - fi - - # Check for XDP support indicators - xdp_mode="unknown" - if [ -f "$iface/xdp" ] || [ -d "$iface/xdp" ]; then - xdp_mode="supported" - fi - - if [ -n "$driver" ]; then - print_info " $iface_name (driver: $driver, xdp: $xdp_mode)" - else - print_info " $iface_name (xdp: $xdp_mode)" - fi - fi - done -fi - -# Check XDP modes availability (based on kernel version) -if [ "$KERNEL_MAJOR" -ge 5 ] || ([ "$KERNEL_MAJOR" -eq 4 ] && [ "$KERNEL_MINOR" -ge 18 ]); then - print_ok "XDP SKB (generic) mode supported" - FEATURES["xdp_skb"]=1 -fi - -if [ "$KERNEL_MAJOR" -ge 5 ] || ([ "$KERNEL_MAJOR" -eq 4 ] && [ "$KERNEL_MINOR" -ge 18 ]); then - print_ok "XDP driver mode potentially supported (driver dependent)" - FEATURES["xdp_driver"]=1 -fi - -if [ "$KERNEL_MAJOR" -ge 5 ]; then - print_ok "XDP hardware offload potentially supported (NIC dependent)" - FEATURES["xdp_hardware"]=1 -fi - -# ============================================================================= 
-# NFTABLES CHECK -# ============================================================================= -print_section "nftables Support" - -if command -v nft &> /dev/null; then - NFT_VERSION=$(nft --version 2>&1 | head -1) - print_ok "nft command available: $NFT_VERSION" - FEATURES["nftables"]=1 - - # Check if we can list tables (may need root) - if nft list tables &> /dev/null; then - print_ok "nft can list tables (have permissions)" - - # Check for existing moat/synapse tables - if nft list tables 2>/dev/null | grep -q "synapse"; then - print_info "Found existing 'synapse' table" - fi - else - print_warn "nft cannot list tables (may need root privileges)" - fi -else - print_fail "nft command not found" - print_info " Install with: apt install nftables (Debian/Ubuntu)" - print_info " Or: yum install nftables (RHEL/CentOS)" -fi - -# Check nftables kernel module -if [ -f /proc/modules ]; then - if grep -q "^nf_tables" /proc/modules 2>/dev/null; then - print_ok "nf_tables kernel module loaded" - else - # Module might be built-in - if [ -d /sys/module/nf_tables ]; then - print_ok "nf_tables module available" - else - print_warn "nf_tables module not loaded" - print_info " Load with: modprobe nf_tables" - fi - fi -fi - -# ============================================================================= -# IPTABLES CHECK -# ============================================================================= -print_section "iptables Support" - -# IPv4 iptables -if command -v iptables &> /dev/null; then - IPT_VERSION=$(iptables --version 2>&1) - print_ok "iptables available: $IPT_VERSION" - FEATURES["iptables"]=1 - - # Check if it's nft backend or legacy - if echo "$IPT_VERSION" | grep -q "nf_tables"; then - print_info " Using nftables backend (iptables-nft)" - elif echo "$IPT_VERSION" | grep -q "legacy"; then - print_info " Using legacy backend (iptables-legacy)" - fi - - # Check permissions - if iptables -L -n &> /dev/null; then - print_ok "iptables can list rules (have permissions)" - - # Check for existing SYNAPSE_BLOCK chain - if iptables -L SYNAPSE_BLOCK &> /dev/null; then - print_info "Found existing 'SYNAPSE_BLOCK' chain" - fi - else - print_warn "iptables cannot list rules (may need root privileges)" - fi -else - print_fail "iptables command not found" -fi - -# IPv6 ip6tables -if command -v ip6tables &> /dev/null; then - IP6T_VERSION=$(ip6tables --version 2>&1) - print_ok "ip6tables available: $IP6T_VERSION" - FEATURES["ip6tables"]=1 - - if ip6tables -L -n &> /dev/null; then - print_ok "ip6tables can list rules (have permissions)" - else - print_warn "ip6tables cannot list rules (may need root privileges)" - fi -else - print_warn "ip6tables command not found (IPv6 filtering via iptables unavailable)" -fi - -# ============================================================================= -# IPV6 SUPPORT -# ============================================================================= -print_section "IPv6 Support" - -if [ -f /proc/sys/net/ipv6/conf/all/disable_ipv6 ]; then - IPV6_DISABLED=$(cat /proc/sys/net/ipv6/conf/all/disable_ipv6 2>/dev/null) - if [ "$IPV6_DISABLED" = "0" ]; then - print_ok "IPv6 is enabled system-wide" - FEATURES["ipv6"]=1 - else - print_warn "IPv6 is disabled system-wide" - print_info " Note: XDP may require IPv6 enabled per-interface" - print_info " Synapse can enable IPv6 per-interface automatically" - fi -else - print_warn "Cannot determine IPv6 status" -fi - -# Check per-interface IPv6 -if [ -d /sys/class/net ]; then - ipv6_interfaces="" - for iface in /sys/class/net/*; do - 
iface_name=$(basename "$iface") - if [ "$iface_name" != "lo" ]; then - if [ -f "/proc/sys/net/ipv6/conf/$iface_name/disable_ipv6" ]; then - ipv6_val=$(cat "/proc/sys/net/ipv6/conf/$iface_name/disable_ipv6" 2>/dev/null) - if [ "$ipv6_val" = "0" ]; then - ipv6_interfaces="$ipv6_interfaces $iface_name(enabled)" - else - ipv6_interfaces="$ipv6_interfaces $iface_name(disabled)" - fi - fi - fi - done - if [ -n "$ipv6_interfaces" ]; then - print_info "Per-interface IPv6:$ipv6_interfaces" - fi -fi - -# ============================================================================= -# CAPABILITIES CHECK -# ============================================================================= -print_section "Required Capabilities" - -print_info "Synapse requires the following capabilities:" -print_info " CAP_SYS_ADMIN - Load/manage BPF programs" -print_info " CAP_NET_ADMIN - Network administration" -print_info " CAP_BPF - BPF operations (kernel >= 5.8)" -print_info " CAP_PERFMON - Performance monitoring" -print_info " CAP_SYS_RESOURCE - Unlimited locked memory" - -# Check if running as root -if [ "$(id -u)" = "0" ]; then - print_ok "Running as root (all capabilities available)" -else - print_warn "Not running as root" - - # Check current capabilities if capsh is available - if command -v capsh &> /dev/null; then - print_info "Current capabilities:" - capsh --print 2>/dev/null | grep "Current" | head -1 || true - fi -fi - -# Check locked memory limit -if command -v ulimit &> /dev/null; then - MEMLOCK=$(ulimit -l 2>/dev/null) - if [ "$MEMLOCK" = "unlimited" ]; then - print_ok "Locked memory limit: unlimited" - else - print_warn "Locked memory limit: ${MEMLOCK}KB" - print_info " Consider: ulimit -l unlimited" - fi -fi - -# ============================================================================= -# SUMMARY -# ============================================================================= -print_header "Feature Summary" - -echo "" -echo -e "${BOLD}Core Features:${NC}" -echo -e " XDP Support: $([ ${FEATURES["xdp"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" -echo -e " - Hardware Offload: $([ ${FEATURES["xdp_hardware"]} -eq 1 ] && echo -e "${GREEN}Possible${NC}" || echo -e "${YELLOW}Unlikely${NC}")" -echo -e " - Driver Mode: $([ ${FEATURES["xdp_driver"]} -eq 1 ] && echo -e "${GREEN}Possible${NC}" || echo -e "${YELLOW}Unlikely${NC}")" -echo -e " - SKB/Generic Mode: $([ ${FEATURES["xdp_skb"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" -echo -e " BPF Syscall: $([ ${FEATURES["bpf_syscall"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" -echo -e " BPF JIT: $([ ${FEATURES["bpf_jit"]} -eq 1 ] && echo -e "${GREEN}Enabled${NC}" || echo -e "${YELLOW}Disabled${NC}")" -echo -e " BTF Support: $([ ${FEATURES["btf"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${YELLOW}Not Available${NC}")" - -echo "" -echo -e "${BOLD}Fallback Firewalls:${NC}" -echo -e " nftables: $([ ${FEATURES["nftables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" -echo -e " iptables (IPv4): $([ ${FEATURES["iptables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" -echo -e " ip6tables (IPv6): $([ ${FEATURES["ip6tables"]} -eq 1 ] && echo -e "${GREEN}Available${NC}" || echo -e "${RED}Not Available${NC}")" - -echo "" -echo -e "${BOLD}Network:${NC}" -echo -e " IPv6: $([ ${FEATURES["ipv6"]} -eq 1 ] && echo -e "${GREEN}Enabled${NC}" || echo -e 
"${YELLOW}Disabled${NC}")" - -echo "" -echo -e "${BOLD}Synapse Operational Mode:${NC}" -if [ ${FEATURES["xdp"]} -eq 1 ] && [ ${FEATURES["bpf_syscall"]} -eq 1 ]; then - echo -e " ${GREEN}XDP mode available${NC} - Best performance" - echo -e " Synapse will use XDP for packet filtering with fallback chain:" - echo -e " 1. Hardware offload (if NIC supports)" - echo -e " 2. Driver mode (if driver supports)" - echo -e " 3. SKB/Generic mode (guaranteed fallback)" -elif [ ${FEATURES["nftables"]} -eq 1 ]; then - echo -e " ${YELLOW}nftables fallback mode${NC}" - echo -e " XDP not available, moat will use nftables for filtering" -elif [ ${FEATURES["iptables"]} -eq 1 ]; then - echo -e " ${YELLOW}iptables fallback mode${NC}" - echo -e " XDP and nftables not available, moat will use iptables" -else - echo -e " ${RED}No packet filtering available!${NC}" - echo -e " Synapse cannot operate without XDP, nftables, or iptables" -fi - -echo "" -echo -e "${BOLD}Recommendations:${NC}" - -if [ ${FEATURES["bpf_jit"]} -eq 0 ]; then - echo -e " ${YELLOW}*${NC} Enable BPF JIT for better performance:" - echo -e " echo 1 > /proc/sys/net/core/bpf_jit_enable" -fi - -if [ ${FEATURES["btf"]} -eq 0 ]; then - echo -e " ${YELLOW}*${NC} Rebuild kernel with CONFIG_DEBUG_INFO_BTF=y for CO-RE support" -fi - -if [ ${FEATURES["ipv6"]} -eq 0 ]; then - echo -e " ${YELLOW}*${NC} Consider enabling IPv6 (XDP may require it per-interface)" -fi - -if [ "$(id -u)" != "0" ]; then - echo -e " ${YELLOW}*${NC} Run moat as root or with required capabilities" -fi - -# No recommendations needed -if [ ${FEATURES["xdp"]} -eq 1 ] && [ ${FEATURES["bpf_syscall"]} -eq 1 ] && \ - [ ${FEATURES["bpf_jit"]} -eq 1 ] && [ ${FEATURES["btf"]} -eq 1 ]; then - echo -e " ${GREEN}System is optimally configured for moat!${NC}" -fi - -echo "" -print_header "Diagnostic Complete" diff --git a/pkg/debug/local_build.sh b/pkg/debug/local_build.sh deleted file mode 100644 index 5eb720f..0000000 --- a/pkg/debug/local_build.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -set -euo pipefail - -PLATFORM=${1:-"linux/amd64"} -LEGACY=${2:-"false"} -BUILD_FLAGS=${3:-""} - -docker rm -f temp-container - -if [ "$LEGACY" = "true" ]; then - DOCKERFILE="pkg/debug/build-legacy.Dockerfile" -else - DOCKERFILE="pkg/debug/build.Dockerfile" -fi - -BUILD_OPTS=() -[ -n "$PLATFORM" ] && BUILD_OPTS+=(--platform "$PLATFORM") -[ -n "$BUILD_FLAGS" ] && BUILD_OPTS+=(--build-arg "BUILD_FLAGS=$BUILD_FLAGS") -docker buildx build "${BUILD_OPTS[@]}" --load -f $DOCKERFILE -t synapse-builder . 
-docker create --name temp-container synapse-builder -docker cp temp-container:/output/synapse ./synapse From 82868513f95172d2d7864f0c9af2f084af78ee2e Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:24:18 +0100 Subject: [PATCH 03/14] conflict error --- src/main.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index f9dc58b..6cda43c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::mem::MaybeUninit; use std::sync::Arc; use std::str::FromStr; use std::fs::File; @@ -271,8 +270,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { libbpf_rs::set_print(None); for iface in &iface_names { - let boxed_open: Box<MaybeUninit<OpenObject>> = Box::new(MaybeUninit::uninit()); - let open_object: &'static mut MaybeUninit<OpenObject> = Box::leak(boxed_open); + let boxed_open: Box<std::mem::MaybeUninit<OpenObject>> = + Box::new(std::mem::MaybeUninit::uninit()); + let open_object: &'static mut std::mem::MaybeUninit<OpenObject> = + Box::leak(boxed_open); let skel_builder = bpf::FilterSkelBuilder::default(); match skel_builder.open(open_object).and_then(|o| o.load()) { Ok(mut skel) => { From 3a6633ae49b7ccd8abe8ae870b1b356e045295f3 Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:24:36 +0100 Subject: [PATCH 04/14] fix: resolve conflict error and improve error handling --- config/synapse.docker.yaml | 26 ++++++++++++++++++++++++++ config/synapse.placeholder.yaml | 30 ++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 config/synapse.docker.yaml create mode 100644 config/synapse.placeholder.yaml diff --git a/config/synapse.docker.yaml b/config/synapse.docker.yaml new file mode 100644 index 0000000..baf7bb6 --- /dev/null +++ b/config/synapse.docker.yaml @@ -0,0 +1,26 @@ +# Synapse Docker config for local Arxignis +mode: "agent" + +network: + iface: "" + ifaces: [] + disable_xdp: true + ip_version: "both" + firewall_mode: "none" + +platform: + api_key: "local-test-key" + base_url: "http://host.docker.internal:8080/v1" + log_sending_enabled: true + include_response_body: false + max_body_size: 1048576 + +logging: + level: "info" + +pingora: + proxy_address_http: "0.0.0.0:8080" + proxy_address_tls: null + upstreams_conf: "/app/upstreams.yaml" + config_address: "0.0.0.0:9090" + config_api_enabled: false diff --git a/config/synapse.placeholder.yaml b/config/synapse.placeholder.yaml new file mode 100644 index 0000000..d112d36 --- /dev/null +++ b/config/synapse.placeholder.yaml @@ -0,0 +1,30 @@ +# Synapse placeholder config for container usage +mode: "agent" + +network: + iface: "" + ifaces: [] + disable_xdp: true + ip_version: "both" + firewall_mode: "none" + +platform: + api_key: "REPLACE_ME" + base_url: "https://api.gen0sec.com/v1" + log_sending_enabled: true + include_response_body: false + max_body_size: 1048576 + +logging: + level: "info" + +redis: + url: "" + prefix: "g0s:synapse" + +pingora: + proxy_address_http: "0.0.0.0:8080" + proxy_address_tls: null + upstreams_conf: "/app/upstreams.yaml" + config_address: "0.0.0.0:9090" + config_api_enabled: false From 7e6612edba46e5e16465483bff13a2fb48664bee Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Wed, 28 Jan 2026 13:27:41 +0100 Subject: [PATCH 05/14] agent only build, agent id improvements --- .github/workflows/build.yaml | 13 +++ Cargo.toml | 62 ++++++---- src/main.rs | 219
++++++++++++++++++++++------------- src/threat/mod.rs | 114 ++++++++++++++++-- src/utils.rs | 9 ++ src/utils/structs.rs | 10 +- src/waf/wirefilter.rs | 1 + src/worker/mod.rs | 1 + 8 files changed, 320 insertions(+), 109 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index dfc453a..a8d9516 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -5,6 +5,19 @@ on: - main jobs: + build-agent-only: + name: Build agent-only (no proxy) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Build agent-only + shell: bash + run: | + set -euxo pipefail + cargo build --locked --no-default-features --features agent-only + build-amd64: name: Build amd64 runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 81a17b9..39a8a38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ rustls = { version = "0.23.36", default-features = false, features = [ "ring", "logging", ] } -rustls-pemfile = "2.1.2" +rustls-pemfile = { version = "2.1.2", optional = true } webpki-roots = "1.0" base64 = "0.22" async-trait = "0.1.81" @@ -71,7 +71,7 @@ uuid = { version = "1.19", features = ["v4", "serde"] } url = "2.5" clamav-tcp = "0.2" multer = "3.0" -proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol", rev = "ac28b27d317088f0e9e89805ada3b9f5cfbf5673"} +proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol", rev = "ac28b27d317088f0e9e89805ada3b9f5cfbf5673", optional = true } rand = "0.9" regex = "1.0" daemonize = "0.5.0" @@ -87,35 +87,35 @@ libbpf-rs = { version = "0.25.0", optional = true } # pingora-memory-cache = { path = "../pingora/pingora-memory-cache" } wirefilter-engine = { git = "https://github.com/gen0sec/wirefilter" , rev = "ab901470a24aad789cb9c03dd214d6c7d4cab589" } -pingora = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", features = ["lb", "openssl", "proxy"] } -pingora-core = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-proxy = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-limits = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-http = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-memory-cache = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} +pingora = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", features = ["lb", "openssl", "proxy"], optional = true } +pingora-core = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", optional = true } +pingora-proxy = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", optional = true } +pingora-limits = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", optional = true } +pingora-http = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", optional = true } +pingora-memory-cache = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", optional = true } mimalloc = { version = "0.1.48", default-features = false } dashmap = "7.0.0-rc2" ctrlc = "3.5.0" arc-swap = "1.7.1" -prometheus = "0.14.0" 
+prometheus = { version = "0.14.0", optional = true } once_cell = "1.21.3" maxminddb = "0.27" memmap2 = "0.9" -axum-server = { version = "0.8.0", features = ["tls-openssl"] } -axum = { version = "0.8.8" } -tower-http = { version = "0.6.8", features = ["fs"] } -tonic = "0.14.2" -port_check = "0.3.0" +axum-server = { version = "0.8.0", features = ["tls-openssl"], optional = true } +axum = { version = "0.8.8", optional = true } +tower-http = { version = "0.6.8", features = ["fs"], optional = true } +tonic = { version = "0.14.2", optional = true } +port_check = { version = "0.3.0", optional = true } notify = "8.2.0" -privdrop = "0.5.6" +privdrop = { version = "0.5.6", optional = true } base16ct = { version = "0.3.0", features = ["alloc"] } nftables = "0.6" iptables = "0.5" -actix-web = "4.12" -actix-files = "0.6" -instant-acme = "0.8" -trust-dns-resolver = "0.23.2" +actix-web = { version = "4.12", optional = true } +actix-files = { version = "0.6", optional = true } +instant-acme = { version = "0.8", optional = true } +trust-dns-resolver = { version = "0.23.2", optional = true } tracing = "0.1" tracing-subscriber = "0.3" @@ -123,6 +123,28 @@ tracing-subscriber = "0.3" serial_test = "3.3" [features] -default = ["bpf"] +default = ["bpf", "proxy"] +proxy = [ + "dep:actix-files", + "dep:actix-web", + "dep:axum", + "dep:axum-server", + "dep:instant-acme", + "dep:pingora", + "dep:pingora-core", + "dep:pingora-http", + "dep:pingora-limits", + "dep:pingora-memory-cache", + "dep:pingora-proxy", + "dep:port_check", + "dep:privdrop", + "dep:prometheus", + "dep:proxy-protocol", + "dep:rustls-pemfile", + "dep:tonic", + "dep:tower-http", + "dep:trust-dns-resolver", +] +agent-only = [] bpf = ["dep:libbpf-rs", "dep:libbpf-cargo", "dep:vmlinux"] disable-bpf = [] diff --git a/src/main.rs b/src/main.rs index 41adc81..649631d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ pub mod access_log; pub mod access_rules; pub mod agent_status; pub mod app_state; +#[cfg(feature = "proxy")] pub mod captcha_server; pub mod cli; pub mod content_scanning; @@ -33,8 +34,10 @@ pub mod http_client; pub mod waf; pub mod threat; pub mod redis; +#[cfg(feature = "proxy")] pub mod proxy_protocol; pub mod authcheck; +#[cfg(feature = "proxy")] pub mod http_proxy; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] pub mod bpf { @@ -53,6 +56,7 @@ pub mod bpf_stats; pub mod ja4_plus; pub mod utils; pub mod worker; +#[cfg(feature = "proxy")] pub mod acme; use tokio::signal; @@ -88,38 +92,48 @@ fn main() -> Result<()> { let args = Args::parse(); // Handle clear certificate command (runs before loading full config) if let Some(certificate_name) = &args.clear_certificate { - // Initialize minimal runtime for async operations - let rt = tokio::runtime::Runtime::new() - .context("Failed to create tokio runtime")?; - - // Load minimal config for Redis connection - let config = Config::load_from_args(&args) - .context("Failed to load configuration")?; - - // Initialize Redis if configured - if !config.redis.url.is_empty() { - rt.block_on(crate::redis::RedisManager::init( - &config.redis.url, - config.redis.prefix.clone(), - config.redis.ssl.as_ref(), - )) - .context("Failed to initialize Redis manager")?; - } + #[cfg(feature = "proxy")] + { + // Initialize minimal runtime for async operations + let rt = tokio::runtime::Runtime::new() + .context("Failed to create tokio runtime")?; + + // Load minimal config for Redis connection + let config = Config::load_from_args(&args) + .context("Failed to load configuration")?; + + // 
Initialize Redis if configured + if !config.redis.url.is_empty() { + rt.block_on(crate::redis::RedisManager::init( + &config.redis.url, + config.redis.prefix.clone(), + config.redis.ssl.as_ref(), + )) + .context("Failed to initialize Redis manager")?; + } - // Get certificate path from config - let certificate_path = config - .pingora - .proxy_certificates - .clone() - .unwrap_or_else(|| "/etc/synapse/certs".to_string()); + // Get certificate path from config + let certificate_path = config + .pingora + .proxy_certificates + .clone() + .unwrap_or_else(|| "/etc/synapse/certs".to_string()); - // Clear the certificate - rt.block_on(crate::worker::certificate::clear_certificate( - certificate_name, - &certificate_path, - ))?; + // Clear the certificate + rt.block_on(crate::worker::certificate::clear_certificate( + certificate_name, + &certificate_path, + ))?; - return Ok(()); + return Ok(()); + } + + #[cfg(not(feature = "proxy"))] + { + return Err(anyhow::anyhow!( + "clear-certificate is not available in agent-only builds" + )); + } } // API key is optional - allow running in local mode without it @@ -128,6 +142,12 @@ fn main() -> Result<()> { let config = Config::load_from_args(&args) .context("Failed to load configuration")?; + if config.mode == "proxy" && !cfg!(feature = "proxy") { + return Err(anyhow::anyhow!( + "proxy mode is not supported in agent-only builds (build with the `proxy` feature)" + )); + } + // Handle daemonization before starting tokio runtime if config.daemon.enabled { let stdout = File::create(&config.daemon.stdout) @@ -477,16 +497,20 @@ async fn async_main(args: Args, config: Config) -> Result<()> { }; // Start the captcha verification server in a separate task (skip in agent mode to save memory) - let captcha_server_enabled = config.mode != "agent"; + let captcha_server_enabled = cfg!(feature = "proxy") && config.mode != "agent"; if captcha_server_enabled { - tokio::spawn(async move { - if let Err(e) = captcha_server::start_captcha_server().await { - error!("Captcha server error: {}", e); - } - }); + #[cfg(feature = "proxy")] + { + tokio::spawn(async move { + if let Err(e) = captcha_server::start_captcha_server().await { + error!("Captcha server error: {}", e); + } + }); + } } // Start embedded ACME server if enabled (skip in agent mode - no TLS termination needed) + #[cfg(feature = "proxy")] if config.acme.enabled && config.mode != "agent" { let acme_config = config.acme.clone(); let pingora_config = config.pingora.clone(); @@ -590,49 +614,52 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Initialize worker manager let (mut worker_manager, _worker_shutdown_rx) = worker::WorkerManager::new(); - // Set ACME config for certificate worker to use (skip in agent mode) - if config.mode != "agent" { - worker::certificate::set_acme_config(config.acme.clone()); - } + #[cfg(feature = "proxy")] + { + // Set ACME config for certificate worker to use (skip in agent mode) + if config.mode != "agent" { + worker::certificate::set_acme_config(config.acme.clone()); + } - // Register certificate worker only if Redis was successfully initialized (skip in agent mode) - if redis_initialized && config.mode != "agent" { - // Parse proxy_certificates from config file (under pingora section) - let certificate_path = if let Some(config_path) = &args.config { - std::fs::read_to_string(config_path) - .ok() - .and_then(|content| serde_yaml::from_str::(&content).ok()) - .and_then(|yaml| { - // Try pingora.proxy_certificates first, then fallback to root level - 
yaml.get("pingora") - .and_then(|pingora| pingora.get("proxy_certificates")) - .or_else(|| yaml.get("proxy_certificates")) - .and_then(|v| v.as_str().map(|s| s.to_string())) - }) - .unwrap_or_else(|| "/tmp/synapse-certs".to_string()) - } else { - "/tmp/synapse-certs".to_string() - }; + // Register certificate worker only if Redis was successfully initialized (skip in agent mode) + if redis_initialized && config.mode != "agent" { + // Parse proxy_certificates from config file (under pingora section) + let certificate_path = if let Some(config_path) = &args.config { + std::fs::read_to_string(config_path) + .ok() + .and_then(|content| serde_yaml::from_str::(&content).ok()) + .and_then(|yaml| { + // Try pingora.proxy_certificates first, then fallback to root level + yaml.get("pingora") + .and_then(|pingora| pingora.get("proxy_certificates")) + .or_else(|| yaml.get("proxy_certificates")) + .and_then(|v| v.as_str().map(|s| s.to_string())) + }) + .unwrap_or_else(|| "/tmp/synapse-certs".to_string()) + } else { + "/tmp/synapse-certs".to_string() + }; - // Set proxy_certificates path for ACME certificate saving - crate::acme::set_proxy_certificates_path(Some(certificate_path.clone())); + // Set proxy_certificates path for ACME certificate saving + crate::acme::set_proxy_certificates_path(Some(certificate_path.clone())); - let refresh_interval = 30; // 30 seconds default refresh interval - let worker_config = worker::WorkerConfig { - name: "certificate".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; + let refresh_interval = 30; // 30 seconds default refresh interval + let worker_config = worker::WorkerConfig { + name: "certificate".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; - let upstreams_path = config.pingora.upstreams_conf.clone(); - let certificate_worker = worker::certificate::CertificateWorker::new( - certificate_path.clone(), - upstreams_path, - refresh_interval - ); + let upstreams_path = config.pingora.upstreams_conf.clone(); + let certificate_worker = worker::certificate::CertificateWorker::new( + certificate_path.clone(), + upstreams_path, + refresh_interval + ); - if let Err(e) = worker_manager.register_worker(worker_config, certificate_worker) { - log::error!("Failed to register certificate worker: {}", e); + if let Err(e) = worker_manager.register_worker(worker_config, certificate_worker) { + log::error!("Failed to register certificate worker: {}", e); + } } } @@ -885,10 +912,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { .filter(|value| !value.trim().is_empty()) .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned()); - let agent_id = std::env::var("AGENT_ID") - .ok() - .filter(|value| !value.trim().is_empty()) - .unwrap_or_else(|| hostname.clone()); + let agent_id = build_agent_id(&hostname); let agent_name = std::env::var("AGENT_NAME") .ok() @@ -1046,6 +1070,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Start the old Pingora proxy system in a separate thread (non-blocking) // Only start if mode is "proxy" (disabled in agent mode) + #[cfg(feature = "proxy")] if config.mode == "proxy" { let bpf_stats_config = config.bpf_stats.clone(); let logging_config = config.logging.clone(); @@ -1159,8 +1184,8 @@ async fn async_main(args: Args, config: Config) -> Result<()> { "content_scanner": content_scanner_enabled, "redis": redis_initialized, "log_sender": log_sender_enabled, - "acme": config.acme.enabled && !is_agent_mode, - "proxy": config.mode != "agent", + "acme": 
cfg!(feature = "proxy") && config.acme.enabled && !is_agent_mode, + "proxy": cfg!(feature = "proxy") && config.mode != "agent", }, "api_configured": has_api_key, }); @@ -1236,6 +1261,42 @@ async fn async_main(args: Args, config: Config) -> Result<()> { std::process::exit(0); } +fn read_env_non_empty(name: &str) -> Option { + std::env::var(name) + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) +} + +fn build_agent_id(hostname: &str) -> String { + if let Some(agent_id) = read_env_non_empty("AGENT_ID") { + return agent_id; + } + + let service = read_env_non_empty("AGENT_SERVICE"); + let instance = read_env_non_empty("AGENT_INSTANCE_ID") + .or_else(|| read_env_non_empty("AGENT_INSTANCE")); + + let mut parts = vec![hostname.to_string()]; + if let Some(service) = service { + parts.push(service); + } + if let Some(instance) = instance { + parts.push(instance); + } + + let derived = parts.join("-"); + if derived == hostname { + warn!( + "AGENT_ID not set; defaulting to hostname '{}'. For multiple agents on one host, \ +set a unique AGENT_ID (e.g., {}-agent-1) or provide AGENT_SERVICE/AGENT_INSTANCE_ID.", + hostname, hostname + ); + } + + derived +} + /// Start a background task that logs BPF statistics periodically fn start_bpf_stats_logging( collector: BpfStatsCollector, diff --git a/src/threat/mod.rs b/src/threat/mod.rs index 744393a..b1dfb2b 100644 --- a/src/threat/mod.rs +++ b/src/threat/mod.rs @@ -5,7 +5,12 @@ use chrono::{DateTime, Utc}; use maxminddb::{geoip2, MaxMindDbError, Reader}; use memmap2::MmapOptions; use std::fs::File; +#[cfg(feature = "proxy")] use pingora_memory_cache::MemoryCache; +#[cfg(not(feature = "proxy"))] +use dashmap::DashMap; +#[cfg(not(feature = "proxy"))] +use std::time::Instant; use serde::{Deserialize, Deserializer, Serialize}; use tokio::sync::{OnceCell, RwLock}; @@ -98,6 +103,97 @@ impl From<&ThreatResponse> for WafFields { } } +#[derive(Clone, Copy)] +struct CacheStatus { + hit: bool, +} + +impl CacheStatus { + fn is_hit(&self) -> bool { + self.hit + } +} + +#[cfg(feature = "proxy")] +struct ThreatCache { + inner: MemoryCache, +} + +#[cfg(feature = "proxy")] +impl ThreatCache { + fn new(max_entries: usize) -> Self { + Self { + inner: MemoryCache::new(max_entries), + } + } + + fn get(&self, key: &str) -> (Option, CacheStatus) { + let (value, status) = self.inner.get(key); + (value, CacheStatus { hit: status.is_hit() }) + } + + fn put(&self, key: &str, value: ThreatResponse, ttl: Option) { + self.inner.put(key, value, ttl); + } +} + +#[cfg(not(feature = "proxy"))] +struct ThreatCacheEntry { + value: ThreatResponse, + expires_at: Instant, +} + +#[cfg(not(feature = "proxy"))] +struct ThreatCache { + entries: DashMap, + max_entries: usize, +} + +#[cfg(not(feature = "proxy"))] +impl ThreatCache { + fn new(max_entries: usize) -> Self { + Self { + entries: DashMap::new(), + max_entries, + } + } + + fn get(&self, key: &str) -> (Option, CacheStatus) { + if let Some(entry) = self.entries.get(key) { + if Instant::now() < entry.expires_at { + return (Some(entry.value.clone()), CacheStatus { hit: true }); + } + } + self.entries.remove(key); + (None, CacheStatus { hit: false }) + } + + fn put(&self, key: &str, value: ThreatResponse, ttl: Option) { + let ttl = match ttl { + Some(t) if !t.is_zero() => t, + _ => return, + }; + + if self.entries.len() >= self.max_entries { + if let Some(oldest_key) = self.entries.iter().next().map(|entry| entry.key().clone()) { + self.entries.remove(&oldest_key); + } + } + + let expires_at = Instant::now() + 
.checked_add(ttl) + .unwrap_or_else(Instant::now); + + self.entries.insert( + key.to_string(), + ThreatCacheEntry { + value, + expires_at, + }, + ); + } +} + /// Threat intel client: Threat MMDB first, then GeoIP MMDB fallback, with in-memory cache pub struct ThreatClient { threat_mmdb_path: Option, @@ -108,7 +204,7 @@ pub struct ThreatClient { geoip_country_reader: RwLock>>>, geoip_asn_reader: RwLock>>>, geoip_city_reader: RwLock>>>, - pingora_cache: Arc>, + cache: Arc, } /// Default cache size for threat response cache (10,000 entries) @@ -152,7 +248,7 @@ impl ThreatClient { geoip_country_reader: RwLock::new(None), geoip_asn_reader: RwLock::new(None), geoip_city_reader: RwLock::new(None), - pingora_cache: Arc::new(MemoryCache::new(cache_size)), + cache: Arc::new(ThreatCache::new(cache_size)), } } @@ -178,10 +274,10 @@ impl ThreatClient { /// Priority: Cache → Threat MMDB → GeoIP MMDB (REST API disabled) pub async fn get_threat_intel(&self, ip: &str) -> Result> { // L1 cache - let (cached, status) = self.pingora_cache.get(ip); + let (cached, status) = self.cache.get(ip); if let Some(data) = cached { if status.is_hit() { - log::debug!("Threat data for {} found in pingora-memory-cache", ip); + log::debug!("Threat data for {} found in cache", ip); return Ok(Some(data)); } } @@ -198,7 +294,7 @@ impl ThreatClient { log::info!("🔍 [Threat] Checking Threat MMDB for {}", ip); if let Some(threat_data) = self.lookup_threat_mmdb(ip, ip_addr).await? { log::info!("🔍 [Threat] Found threat data in Threat MMDB for {}: score={}", ip, threat_data.intel.score); - self.set_pingora_cache(ip, &threat_data).await; + self.set_cache(ip, &threat_data).await; return Ok(Some(threat_data)); } @@ -208,7 +304,7 @@ impl ThreatClient { // GeoIP fallback let (geo, asn, org) = self.lookup_geo(ip_addr).await?; let response = build_no_data_response(ip, ip_addr, geo, asn, org); - self.set_pingora_cache(ip, &response).await; + self.set_cache(ip, &response).await; Ok(Some(response)) } @@ -590,10 +686,10 @@ impl ThreatClient { )) } - /// Set data in pingora-memory-cache with TTL from record - async fn set_pingora_cache(&self, ip: &str, data: &ThreatResponse) { + /// Set data in the threat cache with TTL from record + async fn set_cache(&self, ip: &str, data: &ThreatResponse) { let ttl = Duration::from_secs(data.ttl_s); - self.pingora_cache.put(ip, data.clone(), Some(ttl)); + self.cache.put(ip, data.clone(), Some(ttl)); } } diff --git a/src/utils.rs b/src/utils.rs index b77e967..b0d5d95 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,17 +3,26 @@ pub mod bpf_utils; #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] #[path = "utils/bpf_utils_noop.rs"] pub mod bpf_utils; +#[cfg(feature = "proxy")] pub mod discovery; +#[cfg(feature = "proxy")] mod filewatch; +#[cfg(feature = "proxy")] pub mod healthcheck; pub mod http_utils; +#[cfg(feature = "proxy")] pub mod metrics; +#[cfg(feature = "proxy")] pub mod parceyaml; +#[cfg(feature = "proxy")] pub mod state; pub mod structs; +#[cfg(feature = "proxy")] pub mod tls; +#[cfg(feature = "proxy")] pub mod tls_client_hello; pub mod tls_fingerprint; +#[cfg(feature = "proxy")] pub mod tools; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] pub mod tcp_fingerprint; diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 5b8d5b7..3a2b80f 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -93,6 +93,7 @@ pub struct HostConfig { pub rate_limit: Option, #[serde(default)] pub certificate: Option, + #[cfg(feature = "proxy")] #[serde(default)] pub acme: Option, } @@ 
-103,7 +104,14 @@ impl HostConfig { /// If ssl_enabled is true but no ACME config exists, the user is expected to provide certificates manually. pub fn needs_certificate(&self) -> bool { // Only request certificates if ACME is explicitly configured - self.acme.is_some() + #[cfg(feature = "proxy")] + { + self.acme.is_some() + } + #[cfg(not(feature = "proxy"))] + { + false + } } } diff --git a/src/waf/wirefilter.rs b/src/waf/wirefilter.rs index 8817b79..edd8ad0 100644 --- a/src/waf/wirefilter.rs +++ b/src/waf/wirefilter.rs @@ -861,6 +861,7 @@ pub async fn load_waf_rules(waf_rules: Vec) -> a /// Evaluate WAF rules for a Pingora request /// This is a convenience function that converts Pingora's RequestHeader to hyper's Parts +#[cfg(feature = "proxy")] pub async fn evaluate_waf_for_pingora_request( req_header: &pingora_http::RequestHeader, body_bytes: &[u8], diff --git a/src/worker/mod.rs b/src/worker/mod.rs index def7bc9..01def73 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "proxy")] pub mod certificate; pub mod config; pub mod geoip_mmdb; From ca817957cadb406d5437bdad1bc3bb628f2f7a16 Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:09:39 +0100 Subject: [PATCH 06/14] build fix --- Cargo.toml | 1 + build.rs | 8 ++--- src/main.rs | 5 ++- src/threat/mod.rs | 87 ++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 90 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 39a8a38..66fadfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ vmlinux = { git = "https://github.com/libbpf/vmlinux.h.git", rev = "83a228cf37fc [dependencies] tokio = { version = "1", features = [ + "fs", "rt-multi-thread", "macros", "net", diff --git a/build.rs b/build.rs index ea50476..048b9e8 100644 --- a/build.rs +++ b/build.rs @@ -1,11 +1,6 @@ // build.rs use std::env; -#[cfg(unix)] -use std::ffi::OsStr; -#[cfg(unix)] -use std::path::{Path, PathBuf}; - #[cfg(all(unix, feature = "bpf"))] use libbpf_cargo::SkeletonBuilder; @@ -37,6 +32,9 @@ fn main() { #[cfg(all(unix, feature = "bpf"))] { + use std::ffi::OsStr; + use std::path::{Path, PathBuf}; + let arch = env::var("CARGO_CFG_TARGET_ARCH").expect("CARGO_CFG_TARGET_ARCH must be set"); let vmlinux_include = vmlinux::include_path_root().join(arch); diff --git a/src/main.rs b/src/main.rs index 649631d..fadca13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,7 +61,9 @@ pub mod acme; use tokio::signal; use tokio::sync::watch; -use log::{error, info, warn}; +#[cfg(feature = "proxy")] +use log::{error, info}; +use log::warn; use crate::app_state::AppState; use crate::bpf_stats::BpfStatsCollector; @@ -130,6 +132,7 @@ fn main() -> Result<()> { #[cfg(not(feature = "proxy"))] { + let _ = certificate_name; return Err(anyhow::anyhow!( "clear-certificate is not available in agent-only builds" )); diff --git a/src/threat/mod.rs b/src/threat/mod.rs index b1dfb2b..11d0d2f 100644 --- a/src/threat/mod.rs +++ b/src/threat/mod.rs @@ -104,12 +104,12 @@ impl From<&ThreatResponse> for WafFields { } #[derive(Clone, Copy)] -struct CacheStatus { +pub struct CacheStatus { hit: bool, } impl CacheStatus { - fn is_hit(&self) -> bool { + pub fn is_hit(&self) -> bool { self.hit } } @@ -194,6 +194,83 @@ impl ThreatCache { } } +#[cfg(feature = "proxy")] +pub struct VersionCache { + inner: MemoryCache, +} + +#[cfg(feature = "proxy")] +impl VersionCache { + fn new(max_entries: usize) -> Self { + Self { + inner: MemoryCache::new(max_entries), + } + } + + pub 
fn get(&self, key: &str) -> (Option<String>, CacheStatus) {
+        let (value, status) = self.inner.get(key);
+        (value, CacheStatus { hit: status.is_hit() })
+    }
+
+    pub fn put(&self, key: &str, value: String, ttl: Option<Duration>) {
+        self.inner.put(key, value, ttl);
+    }
+}
+
+#[cfg(not(feature = "proxy"))]
+struct VersionCacheEntry {
+    value: String,
+    expires_at: Option<Instant>,
+}
+
+#[cfg(not(feature = "proxy"))]
+pub struct VersionCache {
+    entries: DashMap<String, VersionCacheEntry>,
+    max_entries: usize,
+}
+
+#[cfg(not(feature = "proxy"))]
+impl VersionCache {
+    fn new(max_entries: usize) -> Self {
+        Self {
+            entries: DashMap::new(),
+            max_entries,
+        }
+    }
+
+    pub fn get(&self, key: &str) -> (Option<String>, CacheStatus) {
+        if let Some(entry) = self.entries.get(key) {
+            if let Some(expires_at) = entry.expires_at {
+                if Instant::now() >= expires_at {
+                    self.entries.remove(key);
+                    return (None, CacheStatus { hit: false });
+                }
+            }
+            return (Some(entry.value.clone()), CacheStatus { hit: true });
+        }
+        (None, CacheStatus { hit: false })
+    }
+
+    pub fn put(&self, key: &str, value: String, ttl: Option<Duration>) {
+        let expires_at = match ttl {
+            Some(t) if !t.is_zero() => Instant::now().checked_add(t),
+            Some(_) => return,
+            None => None,
+        };
+
+        if self.entries.len() >= self.max_entries {
+            if let Some(oldest_key) = self.entries.iter().next().map(|entry| entry.key().clone()) {
+                self.entries.remove(&oldest_key);
+            }
+        }
+
+        self.entries.insert(
+            key.to_string(),
+            VersionCacheEntry { value, expires_at },
+        );
+    }
+}
+
 /// Threat intel client: Threat MMDB first, then GeoIP MMDB fallback, with in-memory cache
 pub struct ThreatClient {
     threat_mmdb_path: Option<String>,
@@ -814,10 +891,10 @@ pub async fn get_threat_intel(ip: &str) -> Result<Option<ThreatResponse>> {
 
 /// Get access to a version cache (separate from threat response cache)
 /// Uses the same pingora-memory-cache pattern as the threat response cache
-pub fn get_version_cache() -> Result<Arc<MemoryCache<String, String>>> {
+pub fn get_version_cache() -> Result<Arc<VersionCache>> {
     use std::sync::OnceLock;
-    static VERSION_CACHE: OnceLock<Arc<MemoryCache<String, String>>> = OnceLock::new();
-    Ok(VERSION_CACHE.get_or_init(|| Arc::new(MemoryCache::new(100))).clone())
+    static VERSION_CACHE: OnceLock<Arc<VersionCache>> = OnceLock::new();
+    Ok(VERSION_CACHE.get_or_init(|| Arc::new(VersionCache::new(100))).clone())
 }
 
 /// Get WAF fields for an IP address

From 885dfa3e7c2b0e39ca0f2294dc25e3b16765502b Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Fri, 30 Jan 2026 11:03:56 +0100
Subject: [PATCH 07/14] agent_id improvements

---
 docs/ENVIRONMNET_VARS.md | 10 ++++++++++
 src/cli.rs               | 13 +++++++++++++
 src/main.rs              | 37 ++++++++++++++-----------------------
 3 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/docs/ENVIRONMNET_VARS.md b/docs/ENVIRONMNET_VARS.md
index 2346d72..876e0ec 100644
--- a/docs/ENVIRONMNET_VARS.md
+++ b/docs/ENVIRONMNET_VARS.md
@@ -2,6 +2,16 @@
 # Application mode
 export AX_MODE="proxy"
 
+# Agent identity (status/heartbeat events)
+# Use a unique AGENT_ID per instance when running multiple agents on the same host.
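The exports below document the agent identity variables consumed at startup. As a rough sketch (using the `read_env_non_empty` helper this patch adds to src/main.rs; the tag splitting mirrors the AGENT_TAGS handling in the main.rs hunk of this series), the agent side reads them like this:

fn read_env_non_empty(name: &str) -> Option<String> {
    std::env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

fn main() {
    // AGENT_TAGS is comma-separated, per the docs entry below.
    let tags: Vec<String> = read_env_non_empty("AGENT_TAGS")
        .map(|v| v.split(',').map(|t| t.trim().to_string()).filter(|t| !t.is_empty()).collect())
        .unwrap_or_default();
    let name = read_env_non_empty("AGENT_NAME").unwrap_or_else(|| "unknown".to_string());
    println!("agent '{}' tags {:?}", name, tags);
}

Note that the main.rs hunk in this same patch warns that AGENT_ID is ignored in favor of an id derived from agent_name and workspace_id.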
+export AGENT_ID="hostA-agent-1" +export AGENT_NAME="edge firewall 1" +export AGENT_SERVICE="synapse" +export AGENT_INSTANCE_ID="1" +export AGENT_TAGS="edge,prod" +export AGENT_IPS="203.0.113.10,2001:db8::10" +export AGENT_HEARTBEAT_SECS="30" + # Redis configuration export AX_REDIS_URL="redis://127.0.0.1/0" export AX_REDIS_PREFIX="ax:synapse" diff --git a/src/cli.rs b/src/cli.rs index b05a258..f8ea421 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -145,6 +145,8 @@ fn default_ip_version() -> String { pub struct Gen0SecConfig { #[serde(default)] pub api_key: String, + #[serde(default)] + pub workspace_id: String, #[serde(default = "default_base_url")] pub base_url: String, /// Threat MMDB database configuration @@ -282,6 +284,7 @@ impl Config { }, platform: Gen0SecConfig { api_key: "".to_string(), + workspace_id: "".to_string(), base_url: "https://api.gen0sec.com/v1".to_string(), threat: GeoipDatabaseConfig::default(), log_sending_enabled: true, @@ -367,6 +370,9 @@ impl Config { if let Some(api_key) = &args.arxignis_api_key { self.platform.api_key = api_key.clone(); } + if !args.arxignis_workspace_id.is_empty() { + self.platform.workspace_id = args.arxignis_workspace_id.clone(); + } if !args.arxignis_base_url.is_empty() && args.arxignis_base_url != "https://api.gen0sec.com/v1" { self.platform.base_url = args.arxignis_base_url.clone(); } @@ -593,6 +599,9 @@ impl Config { if let Some(val) = get_env_arxignis("API_KEY") { self.platform.api_key = val; } + if let Some(val) = get_env_arxignis("WORKSPACE_ID") { + self.platform.workspace_id = val; + } if let Some(val) = get_env_arxignis("BASE_URL") { self.platform.base_url = val; } @@ -728,6 +737,10 @@ pub struct Args { #[arg(long, default_value = "https://api.gen0sec.com/v1")] pub arxignis_base_url: String, + /// Workspace ID for agent identity + #[arg(long, default_value = "")] + pub arxignis_workspace_id: String, + /// Enable sending access logs to arxignis server #[arg(long)] pub arxignis_log_sending_enabled: Option, diff --git a/src/main.rs b/src/main.rs index fadca13..983c2bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use libbpf_rs::skel::{OpenSkel, SkelBuilder}; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] use nix::net::if_::if_nametoindex; use chrono::Utc; +use sha2::{Digest, Sha256}; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -915,13 +916,14 @@ async fn async_main(args: Args, config: Config) -> Result<()> { .filter(|value| !value.trim().is_empty()) .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned()); - let agent_id = build_agent_id(&hostname); - let agent_name = std::env::var("AGENT_NAME") .ok() .filter(|value| !value.trim().is_empty()) .unwrap_or_else(|| hostname.clone()); + let workspace_id = config.platform.workspace_id.clone(); + let agent_id = build_agent_id(&agent_name, &workspace_id); + let tags = std::env::var("AGENT_TAGS") .ok() .map(|value| { @@ -1271,33 +1273,22 @@ fn read_env_non_empty(name: &str) -> Option { .filter(|value| !value.is_empty()) } -fn build_agent_id(hostname: &str) -> String { - if let Some(agent_id) = read_env_non_empty("AGENT_ID") { - return agent_id; - } - - let service = read_env_non_empty("AGENT_SERVICE"); - let instance = read_env_non_empty("AGENT_INSTANCE_ID") - .or_else(|| read_env_non_empty("AGENT_INSTANCE")); - - let mut parts = vec![hostname.to_string()]; - if let Some(service) = service { - parts.push(service); - } - if let Some(instance) = instance { - parts.push(instance); +fn build_agent_id(agent_name: &str, 
workspace_id: &str) -> String { + if read_env_non_empty("AGENT_ID").is_some() { + warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id."); } - let derived = parts.join("-"); - if derived == hostname { + if workspace_id.trim().is_empty() { warn!( - "AGENT_ID not set; defaulting to hostname '{}'. For multiple agents on one host, \ -set a unique AGENT_ID (e.g., {}-agent-1) or provide AGENT_SERVICE/AGENT_INSTANCE_ID.", - hostname, hostname + "WORKSPACE_ID not set; agent_id derived only from agent_name '{}'. Set WORKSPACE_ID \ +(or ARXIGNIS_WORKSPACE_ID) to avoid collisions across organizations.", + agent_name ); } - derived + let input = format!("{}:{}", workspace_id.trim(), agent_name.trim()); + let digest = Sha256::digest(input.as_bytes()); + format!("{:x}", digest) } /// Start a background task that logs BPF statistics periodically From 627f2a663535f8b17b8e97b3211126496b9fdc63 Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Fri, 30 Jan 2026 11:03:56 +0100 Subject: [PATCH 08/14] cleanup --- config/synapse.docker.yaml | 26 ----------------------- config/synapse.placeholder.yaml | 30 -------------------------- docs/ENVIRONMNET_VARS.md | 10 +++++++++ src/cli.rs | 13 ++++++++++++ src/main.rs | 37 +++++++++++++-------------------- 5 files changed, 37 insertions(+), 79 deletions(-) delete mode 100644 config/synapse.docker.yaml delete mode 100644 config/synapse.placeholder.yaml diff --git a/config/synapse.docker.yaml b/config/synapse.docker.yaml deleted file mode 100644 index baf7bb6..0000000 --- a/config/synapse.docker.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# Synapse Docker config for local Arxignis -mode: "agent" - -network: - iface: "" - ifaces: [] - disable_xdp: true - ip_version: "both" - firewall_mode: "none" - -platform: - api_key: "local-test-key" - base_url: "http://host.docker.internal:8080/v1" - log_sending_enabled: true - include_response_body: false - max_body_size: 1048576 - -logging: - level: "info" - -pingora: - proxy_address_http: "0.0.0.0:8080" - proxy_address_tls: null - upstreams_conf: "/app/upstreams.yaml" - config_address: "0.0.0.0:9090" - config_api_enabled: false diff --git a/config/synapse.placeholder.yaml b/config/synapse.placeholder.yaml deleted file mode 100644 index d112d36..0000000 --- a/config/synapse.placeholder.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Synapse placeholder config for container usage -mode: "agent" - -network: - iface: "" - ifaces: [] - disable_xdp: true - ip_version: "both" - firewall_mode: "none" - -platform: - api_key: "REPLACE_ME" - base_url: "https://api.gen0sec.com/v1" - log_sending_enabled: true - include_response_body: false - max_body_size: 1048576 - -logging: - level: "info" - -redis: - url: "" - prefix: "g0s:synapse" - -pingora: - proxy_address_http: "0.0.0.0:8080" - proxy_address_tls: null - upstreams_conf: "/app/upstreams.yaml" - config_address: "0.0.0.0:9090" - config_api_enabled: false diff --git a/docs/ENVIRONMNET_VARS.md b/docs/ENVIRONMNET_VARS.md index 2346d72..876e0ec 100644 --- a/docs/ENVIRONMNET_VARS.md +++ b/docs/ENVIRONMNET_VARS.md @@ -2,6 +2,16 @@ # Application mode export AX_MODE="proxy" +# Agent identity (status/heartbeat events) +# Use a unique AGENT_ID per instance when running multiple agents on the same host. 
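Since this patch carries the same derivation forward, a worked sketch of the build_agent_id scheme may help: the id is the lowercase-hex SHA-256 of `workspace_id:agent_name` (the inputs below are hypothetical):

use sha2::{Digest, Sha256};

// Mirrors build_agent_id in src/main.rs: sha256("<workspace_id>:<agent_name>").
fn derived_agent_id(workspace_id: &str, agent_name: &str) -> String {
    let input = format!("{}:{}", workspace_id.trim(), agent_name.trim());
    format!("{:x}", Sha256::digest(input.as_bytes()))
}

fn main() {
    // With an empty workspace_id the input degenerates to ":edge firewall 1",
    // which is why build_agent_id warns about cross-organization collisions.
    println!("{}", derived_agent_id("ws_123", "edge firewall 1")); // 64 hex chars
}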
+export AGENT_ID="hostA-agent-1" +export AGENT_NAME="edge firewall 1" +export AGENT_SERVICE="synapse" +export AGENT_INSTANCE_ID="1" +export AGENT_TAGS="edge,prod" +export AGENT_IPS="203.0.113.10,2001:db8::10" +export AGENT_HEARTBEAT_SECS="30" + # Redis configuration export AX_REDIS_URL="redis://127.0.0.1/0" export AX_REDIS_PREFIX="ax:synapse" diff --git a/src/cli.rs b/src/cli.rs index b05a258..f8ea421 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -145,6 +145,8 @@ fn default_ip_version() -> String { pub struct Gen0SecConfig { #[serde(default)] pub api_key: String, + #[serde(default)] + pub workspace_id: String, #[serde(default = "default_base_url")] pub base_url: String, /// Threat MMDB database configuration @@ -282,6 +284,7 @@ impl Config { }, platform: Gen0SecConfig { api_key: "".to_string(), + workspace_id: "".to_string(), base_url: "https://api.gen0sec.com/v1".to_string(), threat: GeoipDatabaseConfig::default(), log_sending_enabled: true, @@ -367,6 +370,9 @@ impl Config { if let Some(api_key) = &args.arxignis_api_key { self.platform.api_key = api_key.clone(); } + if !args.arxignis_workspace_id.is_empty() { + self.platform.workspace_id = args.arxignis_workspace_id.clone(); + } if !args.arxignis_base_url.is_empty() && args.arxignis_base_url != "https://api.gen0sec.com/v1" { self.platform.base_url = args.arxignis_base_url.clone(); } @@ -593,6 +599,9 @@ impl Config { if let Some(val) = get_env_arxignis("API_KEY") { self.platform.api_key = val; } + if let Some(val) = get_env_arxignis("WORKSPACE_ID") { + self.platform.workspace_id = val; + } if let Some(val) = get_env_arxignis("BASE_URL") { self.platform.base_url = val; } @@ -728,6 +737,10 @@ pub struct Args { #[arg(long, default_value = "https://api.gen0sec.com/v1")] pub arxignis_base_url: String, + /// Workspace ID for agent identity + #[arg(long, default_value = "")] + pub arxignis_workspace_id: String, + /// Enable sending access logs to arxignis server #[arg(long)] pub arxignis_log_sending_enabled: Option, diff --git a/src/main.rs b/src/main.rs index fadca13..983c2bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use libbpf_rs::skel::{OpenSkel, SkelBuilder}; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] use nix::net::if_::if_nametoindex; use chrono::Utc; +use sha2::{Digest, Sha256}; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -915,13 +916,14 @@ async fn async_main(args: Args, config: Config) -> Result<()> { .filter(|value| !value.trim().is_empty()) .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned()); - let agent_id = build_agent_id(&hostname); - let agent_name = std::env::var("AGENT_NAME") .ok() .filter(|value| !value.trim().is_empty()) .unwrap_or_else(|| hostname.clone()); + let workspace_id = config.platform.workspace_id.clone(); + let agent_id = build_agent_id(&agent_name, &workspace_id); + let tags = std::env::var("AGENT_TAGS") .ok() .map(|value| { @@ -1271,33 +1273,22 @@ fn read_env_non_empty(name: &str) -> Option { .filter(|value| !value.is_empty()) } -fn build_agent_id(hostname: &str) -> String { - if let Some(agent_id) = read_env_non_empty("AGENT_ID") { - return agent_id; - } - - let service = read_env_non_empty("AGENT_SERVICE"); - let instance = read_env_non_empty("AGENT_INSTANCE_ID") - .or_else(|| read_env_non_empty("AGENT_INSTANCE")); - - let mut parts = vec![hostname.to_string()]; - if let Some(service) = service { - parts.push(service); - } - if let Some(instance) = instance { - parts.push(instance); +fn build_agent_id(agent_name: &str, 
workspace_id: &str) -> String {
+    if read_env_non_empty("AGENT_ID").is_some() {
+        warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id.");
     }
 
-    let derived = parts.join("-");
-    if derived == hostname {
+    if workspace_id.trim().is_empty() {
         warn!(
-            "AGENT_ID not set; defaulting to hostname '{}'. For multiple agents on one host, \
-set a unique AGENT_ID (e.g., {}-agent-1) or provide AGENT_SERVICE/AGENT_INSTANCE_ID.",
-            hostname, hostname
+            "WORKSPACE_ID not set; agent_id derived only from agent_name '{}'. Set WORKSPACE_ID \
+(or ARXIGNIS_WORKSPACE_ID) to avoid collisions across organizations.",
+            agent_name
         );
     }
 
-    derived
+    let input = format!("{}:{}", workspace_id.trim(), agent_name.trim());
+    let digest = Sha256::digest(input.as_bytes());
+    format!("{:x}", digest)
 }
 
 /// Start a background task that logs BPF statistics periodically
 fn start_bpf_stats_logging(
     collector: BpfStatsCollector,

From 627f2a663535f8b17b8e97b3211126496b9fdc63 Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Tue, 3 Feb 2026 09:55:07 +0100
Subject: [PATCH 09/14] support new api format

Signed-off-by: krichard1212 <136473183+krichard1212@users.noreply.github.com>
---
 src/access_log.rs                 | 2490 +++++++++++------------
 src/access_rules.rs               | 1438 ++++++-------
 src/acme/config.rs                |  526 ++---
 src/acme/domain_reader.rs         | 1240 +++++------
 src/acme/embedded.rs              |  836 ++++----
 src/acme/errors.rs                |    8 +-
 src/acme/lib.rs                   | 3164 ++++++++++++++---------------
 src/acme/mod.rs                   |   38 +-
 src/acme/storage/mod.rs           |  272 +--
 src/acme/storage/redis.rs         | 1378 ++++++-------
 src/acme/upstreams_reader.rs      |  280 +--
 src/actions/captcha.rs            | 2022 +++++++++---------
 src/actions/mod.rs                |    2 +-
 src/agent_status.rs               |  138 +-
 src/app_state.rs                  |   32 +-
 src/authcheck.rs                  |  116 +-
 src/bpf/filter.bpf.c              | 1532 +++++++-------
 src/bpf/filter.h                  |   22 +-
 src/bpf_stats.rs                  | 1458 ++++++-------
 src/bpf_stats_noop.rs             |  472 ++---
 src/bpf_stub.rs                   |   58 +-
 src/captcha_server.rs             |  306 +--
 src/cli.rs                        | 2466 +++++++++++-----------
 src/content_scanning/mod.rs       | 1304 ++++++------
 src/firewall/iptables.rs          |  612 +++---
 src/firewall/mod.rs               |  596 +++---
 src/firewall/nftables.rs          |  810 ++++----
 src/firewall_noop.rs              |  580 +++---
 src/http_client.rs                |  310 +--
 src/http_proxy.rs                 |   10 +-
 src/http_proxy/bgservice.rs       |  510 ++---
 src/http_proxy/gethosts.rs        |  312 +--
 src/http_proxy/proxyhttp.rs       | 2052 +++++++++----------
 src/http_proxy/start.rs           |  570 +++---
 src/http_proxy/webserver.rs       |  226 +--
 src/ja4_plus.rs                   | 1446 ++++++-------
 src/main.rs                       | 2796 ++++++++++++-------------
 src/proxy_protocol.rs             | 1196 +++++------
 src/threat/mod.rs                 | 2410 +++++++++++-----------
 src/utils.rs                      |   62 +-
 src/utils/bpf_utils.rs            |  450 ++--
 src/utils/bpf_utils_noop.rs       |   32 +-
 src/utils/discovery.rs            |   96 +-
 src/utils/filewatch.rs            |  118 +-
 src/utils/healthcheck.rs          |  322 +--
 src/utils/http_utils.rs           |  172 +-
 src/utils/metrics.rs              |  166 +-
 src/utils/parceyaml.rs            |  632 +++---
 src/utils/state.rs                |   38 +-
 src/utils/structs.rs              |  374 ++--
 src/utils/tcp_fingerprint.rs      | 2236 ++++++++++----------
 src/utils/tcp_fingerprint_noop.rs |  332 +--
 src/utils/tls.rs                  | 1304 ++++++------
 src/utils/tls_client_hello.rs     |  448 ++--
 src/utils/tls_fingerprint.rs      |  672 +++---
 src/utils/tools.rs                |  556 ++---
 src/waf/actions/captcha.rs        | 2022 +++++++++---------
 src/waf/actions/mod.rs            |    4 +-
 src/waf/mod.rs                    |    4 +-
 src/waf/wirefilter.rs             | 2220 ++++++++++----------
 src/worker/agent_status.rs        |  114 +-
 src/worker/certificate.rs         | 2316 ++++++++++-----------
 src/worker/config.rs              | 1096 +++++-----
 src/worker/geoip_mmdb.rs          |  658 +++---
src/worker/log.rs | 826 ++++---- src/worker/manager.rs | 198 +- src/worker/mod.rs | 20 +- src/worker/threat_mmdb.rs | 666 +++--- 68 files changed, 27094 insertions(+), 27094 deletions(-) diff --git a/src/access_log.rs b/src/access_log.rs index b632a08..0661288 100644 --- a/src/access_log.rs +++ b/src/access_log.rs @@ -1,1245 +1,1245 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::time::{SystemTime, UNIX_EPOCH}; - -use chrono::{DateTime, Utc}; -use hyper::{Response, header::HeaderValue}; -use http_body_util::{BodyExt, Full}; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; - -use crate::ja4_plus::{Ja4hFingerprint, Ja4tFingerprint}; -use crate::utils::tcp_fingerprint::TcpFingerprintData; -use crate::worker::log::{get_log_sender_config, send_event, UnifiedEvent}; - -// Re-export for compatibility -pub use crate::worker::log::LogSenderConfig; - -/// Server certificate information for access logging -#[derive(Debug, Clone)] -pub struct ServerCertInfo { - pub issuer: String, - pub subject: String, - pub not_before: String, // RFC3339 format - pub not_after: String, // RFC3339 format - pub fingerprint_sha256: String, -} - -/// Lightweight access log summary for returning with responses -/// -/// # Usage Example -/// -/// ```no_run -/// use synapse::access_log::{AccessLogSummary, UpstreamInfo, PerformanceInfo}; -/// use chrono::Utc; -/// -/// // Create a summary with upstream and performance info -/// let summary = AccessLogSummary { -/// request_id: "req_123".to_string(), -/// timestamp: Utc::now(), -/// upstream: Some(UpstreamInfo { -/// selected: "backend1.example.com".to_string(), -/// method: "round_robin".to_string(), -/// reason: "healthy".to_string(), -/// }), -/// waf: None, -/// threat: None, -/// network: synapse::access_log::NetworkSummary { -/// src_ip: "1.2.3.4".to_string(), -/// dst_ip: "10.0.0.1".to_string(), -/// protocol: "https".to_string(), -/// }, -/// performance: PerformanceInfo { -/// request_time_ms: Some(150), -/// upstream_time_ms: Some(120), -/// }, -/// }; -/// -/// // Add to response headers -/// // summary.add_to_response_headers(&mut response); -/// -/// // Or get as JSON -/// let json = summary.to_json().unwrap(); -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AccessLogSummary { - pub request_id: String, - pub timestamp: DateTime, - pub upstream: Option, - pub waf: Option, - pub threat: Option, - pub network: NetworkSummary, - pub performance: PerformanceInfo, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpstreamInfo { - pub selected: String, - pub method: String, - pub reason: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WafInfo { - pub action: String, - pub rule_id: String, - pub rule_name: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ThreatInfo { - pub score: u32, - pub confidence: f64, - pub categories: Vec, - pub reason: String, - pub country: Option, - pub asn: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct NetworkSummary { - pub src_ip: String, - pub dst_ip: String, - pub protocol: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PerformanceInfo { - pub request_time_ms: Option, - pub upstream_time_ms: Option, -} - -impl AccessLogSummary { - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - /// Convert to compact JSON for headers (excludes null fields) - pub fn to_compact_json(&self) -> String { - let mut parts = 
vec![format!(r#""request_id":"{}""#, self.request_id)]; - - if let Some(upstream) = &self.upstream { - parts.push(format!(r#""upstream":"{}""#, upstream.selected)); - parts.push(format!(r#""upstream_method":"{}""#, upstream.method)); - } - - if let Some(waf) = &self.waf { - parts.push(format!(r#""waf_action":"{}""#, waf.action)); - parts.push(format!(r#""waf_rule":"{}""#, waf.rule_name)); - } - - if let Some(threat) = &self.threat { - parts.push(format!(r#""threat_score":{}"#, threat.score)); - parts.push(format!(r#""threat_confidence":{:.2}"#, threat.confidence)); - } - - if let Some(ms) = self.performance.request_time_ms { - parts.push(format!(r#""request_time_ms":{}"#, ms)); - } - - format!("{{{}}}", parts.join(",")) - } - - /// Add as response headers - pub fn add_to_response_headers(&self, response: &mut Response>) { - let headers = response.headers_mut(); - - // Add request ID header - if let Ok(value) = HeaderValue::from_str(&self.request_id) { - headers.insert("X-Request-ID", value); - } - - // Add upstream info - if let Some(upstream) = &self.upstream { - if let Ok(value) = HeaderValue::from_str(&upstream.selected) { - headers.insert("X-Upstream-Server", value); - } - if let Ok(value) = HeaderValue::from_str(&upstream.method) { - headers.insert("X-Upstream-Method", value); - } - } - - // Add WAF info - if let Some(waf) = &self.waf { - if let Ok(value) = HeaderValue::from_str(&waf.action) { - headers.insert("X-WAF-Action", value); - } - if let Ok(value) = HeaderValue::from_str(&waf.rule_id) { - headers.insert("X-WAF-Rule-ID", value); - } - } - - // Add threat info - if let Some(threat) = &self.threat { - if let Ok(value) = HeaderValue::from_str(&threat.score.to_string()) { - headers.insert("X-Threat-Score", value); - } - if let Some(country) = &threat.country { - if let Ok(value) = HeaderValue::from_str(country) { - headers.insert("X-Client-Country", value); - } - } - } - - // Add performance metrics - if let Some(ms) = self.performance.request_time_ms { - if let Ok(value) = HeaderValue::from_str(&ms.to_string()) { - headers.insert("X-Request-Time-Ms", value); - } - } - - if let Some(ms) = self.performance.upstream_time_ms { - if let Ok(value) = HeaderValue::from_str(&ms.to_string()) { - headers.insert("X-Upstream-Time-Ms", value); - } - } - - // Add compact JSON summary - let compact = self.to_compact_json(); - if let Ok(value) = HeaderValue::from_str(&compact) { - headers.insert("X-Access-Log", value); - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HttpAccessLog { - pub event_type: String, - pub schema_version: String, - pub timestamp: DateTime, - pub request_id: String, - pub http: HttpDetails, - pub network: NetworkDetails, - pub tls: Option, - pub response: ResponseDetails, - pub remediation: Option, - pub upstream: Option, - pub performance: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HttpDetails { - pub method: String, - pub scheme: String, - pub host: String, - pub port: u16, - pub path: String, - pub query: String, - pub query_hash: Option, - pub headers: HashMap, - pub ja4h: Option, - pub user_agent: Option, - pub content_type: Option, - pub content_length: Option, - pub body: String, - pub body_sha256: String, - pub body_truncated: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct NetworkDetails { - pub src_ip: String, - pub src_port: u16, - pub dst_ip: String, - pub dst_port: u16, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TlsDetails { - pub version: String, - pub 
cipher: String, - pub alpn: Option, - pub sni: Option, - pub ja4: Option, - pub ja4_unsorted: Option, - pub ja4t: Option, - pub server_cert: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCertDetails { - pub issuer: String, - pub subject: String, - pub not_before: DateTime, - pub not_after: DateTime, - pub fingerprint_sha256: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ResponseDetails { - pub status: u16, - pub status_text: String, - pub content_type: Option, - pub content_length: Option, - pub body: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RemediationDetails { - pub waf_action: Option, - pub waf_rule_id: Option, - pub waf_rule_name: Option, - pub threat_score: Option, - pub threat_confidence: Option, - pub threat_categories: Option>, - pub threat_tags: Option>, - pub threat_reason_code: Option, - pub threat_reason_summary: Option, - pub threat_advice: Option, - pub ip_country: Option, - pub ip_asn: Option, - pub ip_asn_org: Option, - pub ip_asn_country: Option, -} - -impl HttpAccessLog { - /// Create access log from request parts and response data - pub async fn create_from_parts( - req_parts: &hyper::http::request::Parts, - req_body_bytes: &bytes::Bytes, - peer_addr: SocketAddr, - dst_addr: SocketAddr, - tls_fingerprint: Option<&crate::ja4_plus::Ja4hFingerprint>, - tcp_fingerprint_data: Option<&TcpFingerprintData>, - server_cert_info: Option<&ServerCertInfo>, - response_data: ResponseData, - waf_result: Option<&crate::waf::wirefilter::WafResult>, - threat_data: Option<&crate::threat::ThreatResponse>, - upstream_info: Option, - performance_info: Option, - tls_sni: Option, - tls_alpn: Option, - tls_cipher: Option, - tls_ja4: Option, - tls_ja4_unsorted: Option, - ) -> Result<(), Box> { - let timestamp = Utc::now(); - let request_id = format!("req_{}", SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos()); - - // Extract request details - let uri = &req_parts.uri; - let method = req_parts.method.to_string(); - - // Determine scheme: prefer URI scheme, fallback to TLS fingerprint presence, then default to http - let scheme = uri.scheme().map(|s| s.to_string()).unwrap_or_else(|| { - if tls_fingerprint.is_some() { - "https".to_string() - } else { - "http".to_string() - } - }); - - // Extract host from URI, fallback to Host header if URI doesn't have host - let host = uri.host().map(|h| h.to_string()).unwrap_or_else(|| { - req_parts.headers - .get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - .unwrap_or_else(|| "unknown".to_string()) - }); - - // Determine port: prefer URI port, fallback to scheme-based default - let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 }); - let path = uri.path().to_string(); - let query = uri.query().unwrap_or("").to_string(); - - // Process headers - let mut headers = HashMap::new(); - let mut user_agent = None; - let mut content_type = None; - - for (name, value) in req_parts.headers.iter() { - let key = name.to_string(); - let val = value.to_str().unwrap_or("").to_string(); - headers.insert(key, val.clone()); - - if name.as_str().to_lowercase() == "user-agent" { - user_agent = Some(val.clone()); - } - if name.as_str().to_lowercase() == "content-type" { - content_type = Some(val); - } - } - - // Generate JA4H fingerprint - let ja4h_fp = Ja4hFingerprint::from_http_request( - req_parts.method.as_str(), - &format!("{:?}", req_parts.version), - &req_parts.headers - ); - - // Get log 
sender configuration for body processing - let log_config = { - let config_store = get_log_sender_config(); - let config_guard = config_store.read().unwrap(); - config_guard.as_ref().cloned() - }; - - // Process request body with truncation - respect include_request_body setting - let (body_str, body_sha256, body_truncated) = if let Some(config) = &log_config { - if !config.include_request_body { - // Request body logging disabled - ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) - } else { - let max_body_size = config.max_body_size; - let truncated = req_body_bytes.len() > max_body_size; - let truncated_body_bytes = if truncated { - req_body_bytes.slice(..max_body_size) - } else { - req_body_bytes.clone() - }; - let body = String::from_utf8_lossy(&truncated_body_bytes).to_string(); - - // Calculate SHA256 hash - handle empty body explicitly - let hash = if req_body_bytes.is_empty() { - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string() - } else { - format!("{:x}", Sha256::digest(req_body_bytes)) - }; - - (body, hash, truncated) - } - } else { - // No config, default to disabled - ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) - }; - - // Generate JA4T from TCP fingerprint data if available - let ja4t = tcp_fingerprint_data.map(|tcp_data| { - let ja4t_fp = Ja4tFingerprint::from_tcp_data( - tcp_data.window_size, - tcp_data.ttl, - tcp_data.mss, - tcp_data.window_scale, - &tcp_data.options, - ); - ja4t_fp.fingerprint - }); - - // Process TLS details - let tls_details = if let Some(fp) = tls_fingerprint { - // Use actual TLS version from fingerprint if available, otherwise infer from HTTP version - let tls_version = if scheme == "https" { - // Check if version looks like TLS version (e.g., "TLS 1.2", "TLS 1.3") - if fp.version.starts_with("TLS") { - fp.version.clone() - } else { - // Otherwise infer from HTTP version - match fp.version.as_str() { - "2.0" | "2" => "TLS 1.2".to_string(), // HTTP/2 typically uses TLS 1.2+ - "3.0" | "3" => "TLS 1.3".to_string(), // HTTP/3 uses TLS 1.3 - _ => "TLS 1.2".to_string(), // Default for HTTPS - } - } - } else { - "".to_string() // No TLS for HTTP - }; - - // Determine cipher - use provided cipher or infer from TLS version - let cipher = if let Some(ref provided_cipher) = tls_cipher { - provided_cipher.clone() - } else if scheme == "https" { - match fp.version.as_str() { - "3.0" | "3" => "TLS_AES_256_GCM_SHA384".to_string(), // HTTP/3 uses TLS 1.3 - "2.0" | "2" => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // HTTP/2 typically uses TLS 1.2 - _ => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // Default TLS 1.2 cipher - } - } else { - "".to_string() // No cipher for HTTP - }; - - // Extract server certificate details if available - let server_cert = extract_server_cert_details(server_cert_info); - - // Extract JA4 from TLS fingerprint data - use tls_ja4 if available - let ja4_value = tls_ja4.clone(); - - Some(TlsDetails { - version: tls_version, - cipher, - alpn: tls_alpn.clone(), - sni: tls_sni.clone(), - ja4: ja4_value, - ja4_unsorted: tls_ja4_unsorted.clone(), - ja4t: ja4t.clone(), - server_cert, - }) - } else if scheme == "https" { - // Create minimal TLS details for HTTPS connections without fingerprint (e.g., PROXY protocol) - let server_cert = extract_server_cert_details(server_cert_info); - - Some(TlsDetails { - version: "TLS 1.3".to_string(), - cipher: "TLS_AES_256_GCM_SHA384".to_string(), - 
alpn: None, - sni: None, - ja4: Some("t13d".to_string()), - ja4_unsorted: Some("t13d".to_string()), - ja4t: ja4t.clone(), - server_cert, - }) - } else { - None - }; - - // Create HTTP details - let http_details = HttpDetails { - method, - scheme, - host, - port, - path, - query: query.clone(), - query_hash: if query.is_empty() { None } else { Some(format!("{:x}", Sha256::digest(query.as_bytes()))) }, - headers, - ja4h: Some(ja4h_fp.fingerprint.clone()), - user_agent, - content_type, - content_length: Some(req_body_bytes.len() as u64), - body: body_str, - body_sha256, - body_truncated, - }; - - // Create network details - let network_details = NetworkDetails { - src_ip: peer_addr.ip().to_string(), - src_port: peer_addr.port(), - dst_ip: dst_addr.ip().to_string(), - dst_port: dst_addr.port(), - }; - - // Create response details from response_data - response body logging disabled - let response_details = ResponseDetails { - status: response_data.response_json["status"].as_u64().unwrap_or(0) as u16, - status_text: response_data.response_json["status_text"].as_str().unwrap_or("Unknown").to_string(), - content_type: response_data.response_json["content_type"].as_str().map(|s| s.to_string()), - content_length: response_data.response_json["content_length"].as_u64(), - body: "".to_string(), - }; - - // Create remediation details - let remediation_details = Self::create_remediation_details(waf_result, threat_data); - - // Create the access log - let access_log = HttpAccessLog { - event_type: "http_access_log".to_string(), - schema_version: "1.0.0".to_string(), - timestamp, - request_id, - http: http_details, - network: network_details, - tls: tls_details, - response: response_details, - remediation: remediation_details, - upstream: upstream_info, - performance: performance_info, - }; - - // Log to stdout (existing behavior) - if let Err(e) = access_log.log_to_stdout() { - log::warn!("Failed to log access log to stdout: {}", e); - } - - // Send to unified event queue - send_event(UnifiedEvent::HttpAccessLog(access_log)); - - Ok(()) - } - - /// Create remediation details from WAF result and threat intelligence data - fn create_remediation_details( - waf_result: Option<&crate::waf::wirefilter::WafResult>, - threat_data: Option<&crate::threat::ThreatResponse>, - ) -> Option { - // Check if WAF action requires remediation (Block/Challenge/RateLimit) - these will populate WAF fields - // RateLimit is included because it blocks requests when the limit is exceeded - let has_waf_remediation = match waf_result { - Some(waf) => matches!( - waf.action, - crate::waf::wirefilter::WafAction::Block - | crate::waf::wirefilter::WafAction::Challenge - | crate::waf::wirefilter::WafAction::RateLimit - ), - None => false, - }; - - // Check if threat data is meaningful (not just default/empty values) - let has_meaningful_threat_data = threat_data.map(|threat| { - threat.intel.score > 0 - || threat.intel.reason_code != "NO_DATA" - || !threat.intel.categories.is_empty() - || !threat.intel.tags.is_empty() - }).unwrap_or(false); - - // Create remediation section if: - // 1. WAF action is Block/Challenge/RateLimit (will populate WAF fields), OR - // 2. 
There's meaningful threat intelligence data (will populate threat fields) - // Note: WAF fields (waf_action, waf_rule_id, waf_rule_name) are populated for Block/Challenge/RateLimit - // RateLimit is included because it blocks requests when the limit is exceeded - // Allow actions don't populate WAF fields, but remediation section can still exist if there's meaningful threat data - if !has_waf_remediation && !has_meaningful_threat_data { - return None; - } - - let mut remediation = RemediationDetails { - waf_action: None, - waf_rule_id: None, - waf_rule_name: None, - threat_score: None, - threat_confidence: None, - threat_categories: None, - threat_tags: None, - threat_reason_code: None, - threat_reason_summary: None, - threat_advice: None, - ip_country: None, - ip_asn: None, - ip_asn_org: None, - ip_asn_country: None, - }; - - // Populate WAF data if available for actions that require remediation (Block/Challenge/RateLimit) - // RateLimit actions also populate WAF fields because they block/challenge requests when exceeded - // Allow actions don't populate WAF fields, but remediation section may still exist if there's meaningful threat data - if let Some(waf) = waf_result { - // Include WAF data in remediation for Block, Challenge, and RateLimit actions - // RateLimit is included because it blocks requests when the limit is exceeded - match waf.action { - crate::waf::wirefilter::WafAction::Block - | crate::waf::wirefilter::WafAction::Challenge - | crate::waf::wirefilter::WafAction::RateLimit => { - remediation.waf_action = Some(format!("{:?}", waf.action).to_lowercase()); - remediation.waf_rule_id = Some(waf.rule_id.clone()); - remediation.waf_rule_name = Some(waf.rule_name.clone()); - } - crate::waf::wirefilter::WafAction::Allow => { - // Allow actions don't populate WAF fields - // But remediation section may still exist if there's meaningful threat data - } - } - } - - // Populate threat intelligence data if available and meaningful - if let Some(threat) = threat_data { - // Always populate GeoIP fields (country, ASN) when available, regardless of threat score - // These are geographic/network identifiers that should always be included - let country_code = threat.context.geo.iso_code.clone(); - remediation.ip_country = Some(country_code); - remediation.ip_asn = Some(threat.context.asn); - remediation.ip_asn_org = Some(threat.context.org.clone()); - remediation.ip_asn_country = Some(threat.context.geo.asn_iso_code.clone()); - - // Only include threat-specific data if it's meaningful (not just default/empty values) - if has_meaningful_threat_data { - remediation.threat_score = Some(threat.intel.score); - remediation.threat_confidence = Some(threat.intel.confidence); - remediation.threat_categories = Some(threat.intel.categories.clone()); - remediation.threat_tags = Some(threat.intel.tags.clone()); - remediation.threat_reason_code = Some(threat.intel.reason_code.clone()); - remediation.threat_reason_summary = Some(threat.intel.reason_summary.clone()); - remediation.threat_advice = Some(threat.advice.clone()); - } - } - - // Only return remediation if it has any meaningful data (WAF fields for Block/Challenge, meaningful threat data, or GeoIP data) - let has_waf_data = remediation.waf_action.is_some(); - let has_threat_data = remediation.threat_score.is_some() || remediation.threat_reason_code.is_some(); - let has_geoip_data = remediation.ip_country.is_some() || remediation.ip_asn.is_some(); - - if has_waf_data || has_threat_data || has_geoip_data { - Some(remediation) - } else { - None - } 
-    }
-
-    pub fn to_json(&self) -> Result<String, serde_json::Error> {
-        serde_json::to_string(self)
-    }
-
-    pub fn log_to_stdout(&self) -> Result<(), Box<dyn std::error::Error>> {
-        let json = self.to_json()?;
-        log::info!("{}", json);
-        Ok(())
-    }
-
-    /// Create a lightweight summary suitable for returning with responses
-    pub fn to_summary(&self) -> AccessLogSummary {
-        let waf_info = if let Some(remediation) = &self.remediation {
-            if let (Some(action), Some(rule_id), Some(rule_name)) =
-                (&remediation.waf_action, &remediation.waf_rule_id, &remediation.waf_rule_name) {
-                Some(WafInfo {
-                    action: action.clone(),
-                    rule_id: rule_id.clone(),
-                    rule_name: rule_name.clone(),
-                })
-            } else {
-                None
-            }
-        } else {
-            None
-        };
-
-        let threat_info = if let Some(remediation) = &self.remediation {
-            if let (Some(score), Some(confidence)) =
-                (remediation.threat_score, remediation.threat_confidence) {
-                Some(ThreatInfo {
-                    score,
-                    confidence,
-                    categories: remediation.threat_categories.clone().unwrap_or_default(),
-                    reason: remediation.threat_reason_summary.clone().unwrap_or_default(),
-                    country: remediation.ip_country.clone(),
-                    asn: remediation.ip_asn,
-                })
-            } else {
-                None
-            }
-        } else {
-            None
-        };
-
-        let protocol = if self.tls.is_some() {
-            format!("{} over {}", self.http.scheme,
-                self.tls.as_ref().map(|t| t.version.as_str()).unwrap_or("TLS"))
-        } else {
-            self.http.scheme.clone()
-        };
-
-        AccessLogSummary {
-            request_id: self.request_id.clone(),
-            timestamp: self.timestamp,
-            upstream: self.upstream.clone(),
-            waf: waf_info,
-            threat: threat_info,
-            network: NetworkSummary {
-                src_ip: self.network.src_ip.clone(),
-                dst_ip: self.network.dst_ip.clone(),
-                protocol,
-            },
-            performance: self.performance.clone().unwrap_or(PerformanceInfo {
-                request_time_ms: None,
-                upstream_time_ms: None,
-            }),
-        }
-    }
-
-    /// Add upstream routing information to the access log
-    pub fn with_upstream(mut self, upstream: UpstreamInfo) -> Self {
-        self.upstream = Some(upstream);
-        self
-    }
-
-    /// Add performance metrics to the access log
-    pub fn with_performance(mut self, performance: PerformanceInfo) -> Self {
-        self.performance = Some(performance);
-        self
-    }
-}
-
-/// Helper struct to hold response data for access logging
-#[derive(Debug, Clone)]
-pub struct ResponseData {
-    pub response_json: serde_json::Value,
-    pub blocking_info: Option<serde_json::Value>,
-    pub waf_result: Option<crate::waf::wirefilter::WafResult>,
-    pub threat_data: Option<crate::threat::ThreatResponse>,
-}
-
-impl ResponseData {
-    /// Create response data for a regular response
-    pub async fn from_response(response: Response<Full<bytes::Bytes>>) -> Result<Self, Box<dyn std::error::Error>> {
-        let (response_parts, response_body) = response.into_parts();
-        let response_body_bytes = response_body.collect().await?.to_bytes();
-        let response_body_str = String::from_utf8_lossy(&response_body_bytes).to_string();
-
-        let response_content_type = response_parts.headers
-            .get("content-type")
-            .and_then(|h| h.to_str().ok())
-            .map(|s| s.to_string());
-
-        let response_json = serde_json::json!({
-            "status": response_parts.status.as_u16(),
-            "status_text": response_parts.status.canonical_reason().unwrap_or("Unknown"),
-            "content_type": response_content_type,
-            "content_length": response_body_bytes.len() as u64,
-            "body": response_body_str
-        });
-
-        Ok(ResponseData {
-            response_json,
-            blocking_info: None,
-            waf_result: None,
-            threat_data: None,
-        })
-    }
-
-    /// Create response data for a blocked request
-    pub fn for_blocked_request(
-        block_reason: &str,
-        status_code: u16,
-        waf_result: Option<crate::waf::wirefilter::WafResult>,
-        threat_data: Option<&crate::threat::ThreatResponse>,
-    ) -> Self {
-        let status_text = match status_code {
-            403 => "Forbidden",
-            426 => "Upgrade Required",
-            429 => "Too Many Requests",
-            _ => "Blocked"
-        };
-
-        let response_json = serde_json::json!({
-            "status": status_code,
-            "status_text": status_text,
-            "content_type": "application/json",
-            "content_length": 0,
-            "body": format!("{{\"ok\":false,\"error\":\"{}\"}}", block_reason)
-        });
-
-        let blocking_info = serde_json::json!({
-            "blocked": true,
-            "reason": block_reason,
-            "filter_type": "waf"
-        });
-
-        ResponseData {
-            response_json,
-            blocking_info: Some(blocking_info),
-            waf_result,
-            threat_data: threat_data.cloned(),
-        }
-    }
-
-    /// Create response data for a malware-blocked request with scan details
-    pub fn for_malware_blocked_request(
-        signature: Option<String>,
-        scan_error: Option<String>,
-        waf_result: Option<crate::waf::wirefilter::WafResult>,
-        threat_data: Option<&crate::threat::ThreatResponse>,
-    ) -> Self {
-        let response_json = serde_json::json!({
-            "status": 403,
-            "status_text": "Forbidden",
-            "content_type": "application/json",
-            "content_length": 0,
-            "body": "{\"ok\":false,\"error\":\"malware_detected\"}"
-        });
-
-        let mut blocking_info = serde_json::json!({
-            "blocked": true,
-            "reason": "malware_detected",
-            "filter_type": "content_scanning",
-            "malware_detected": true,
-        });
-
-        if let Some(sig) = signature {
-            blocking_info["malware_signature"] = serde_json::Value::String(sig);
-        }
-
-        if let Some(err) = scan_error {
-            blocking_info["scan_error"] = serde_json::Value::String(err);
-        }
-
-        ResponseData {
-            response_json,
-            blocking_info: Some(blocking_info),
-            waf_result,
-            threat_data: threat_data.cloned(),
-        }
-    }
-}
-
-
-/// Extract server certificate details from server certificate info
-fn extract_server_cert_details(server_cert_info: Option<&ServerCertInfo>) -> Option<ServerCertDetails> {
-    server_cert_info.map(|cert_info| {
-        // Parse the date strings from ServerCertInfo
-        let not_before = chrono::DateTime::parse_from_rfc3339(&cert_info.not_before)
-            .unwrap_or_else(|_| Utc::now().into())
-            .with_timezone(&Utc);
-        let not_after = chrono::DateTime::parse_from_rfc3339(&cert_info.not_after)
-            .unwrap_or_else(|_| Utc::now().into())
-            .with_timezone(&Utc);
-
-        ServerCertDetails {
-            issuer: cert_info.issuer.clone(),
-            subject: cert_info.subject.clone(),
-            not_before,
-            not_after,
-            fingerprint_sha256: cert_info.fingerprint_sha256.clone(),
-        }
-    })
-}
-
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use hyper::Request;
-
-    #[tokio::test]
-    async fn test_access_log_creation() {
-        // Create a simple request
-        let _req = Request::builder()
-            .method("GET")
-            .uri("https://example.com/test?param=value")
-            .header("User-Agent", format!("TestAgent/{}", env!("CARGO_PKG_VERSION")))
-            .body(Full::new(bytes::Bytes::new()))
-            .unwrap();
-
-        // Create a simple response
-        let _response = Response::builder()
-            .status(200)
-            .header("Content-Type", "application/json")
-            .body(Full::new(bytes::Bytes::from("{\"ok\":true}")))
-            .unwrap();
-
-        let _peer: SocketAddr = "127.0.0.1:12345".parse().unwrap();
-        let _dst_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
-
-        // This test would need more setup to work properly
-        // For now, just test the structure creation
-        let log = HttpAccessLog {
-            event_type: "http_access_log".to_string(),
-            schema_version: "1.0.0".to_string(),
-            timestamp: Utc::now(),
-            request_id: "test_123".to_string(),
-            http: HttpDetails {
-                method: "GET".to_string(),
-                scheme: "https".to_string(),
-                host: "example.com".to_string(),
-                port: 443,
-                path: "/test".to_string(),
-                query: "param=value".to_string(),
-                query_hash: Some("abc123".to_string()),
-                headers: HashMap::new(),
-                ja4h:
Some("g11n_000000000000_000000000000".to_string()), - user_agent: Some(format!("TestAgent/{}", env!("CARGO_PKG_VERSION"))), - content_type: None, - content_length: None, - body: "".to_string(), - body_sha256: "abc123".to_string(), - body_truncated: false, - }, - network: NetworkDetails { - src_ip: "127.0.0.1".to_string(), - src_port: 12345, - dst_ip: "127.0.0.1".to_string(), - dst_port: 443, - }, - tls: None, - response: ResponseDetails { - status: 200, - status_text: "OK".to_string(), - content_type: Some("application/json".to_string()), - content_length: Some(10), - body: "{\"ok\":true}".to_string(), - }, - remediation: None, - upstream: Some(UpstreamInfo { - selected: "backend1".to_string(), - method: "round_robin".to_string(), - reason: "healthy".to_string(), - }), - performance: Some(PerformanceInfo { - request_time_ms: Some(50), - upstream_time_ms: Some(45), - }), - }; - - let json = log.to_json().unwrap(); - assert!(json.contains("http_access_log")); - assert!(json.contains("GET")); - assert!(json.contains("example.com")); - assert!(json.contains("backend1")); - - // Test summary creation - let summary = log.to_summary(); - assert_eq!(summary.request_id, "test_123"); - assert_eq!(summary.upstream.as_ref().unwrap().selected, "backend1"); - assert_eq!(summary.performance.request_time_ms, Some(50)); - } - - #[test] - fn test_remediation_with_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; - - // Create a mock threat response - let threat_response = ThreatResponse { - schema_version: "1.0.0".to_string(), - tenant_id: "test-tenant".to_string(), - ip: "192.168.1.100".to_string(), - intel: ThreatIntel { - score: 85, - confidence: 0.95, - score_version: "1.0".to_string(), - categories: vec!["malware".to_string(), "botnet".to_string()], - tags: vec!["suspicious".to_string()], - first_seen: Some(Utc::now()), - last_seen: Some(Utc::now()), - source_count: 5, - reason_code: "THREAT_DETECTED".to_string(), - reason_summary: "IP address associated with malicious activity".to_string(), - rule_id: "rule-123".to_string(), - }, - context: ThreatContext { - asn: 12345, - org: "Test ISP".to_string(), - ip_version: 4, - geo: GeoInfo { - country: "United States".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }, - }, - advice: "Block this IP address".to_string(), - ttl_s: 3600, - generated_at: Utc::now(), - }; - - // Create a WAF result with threat intelligence - let waf_result = WafResult { - action: WafAction::Block, - rule_name: "Threat intelligence - Block".to_string(), - rule_id: "aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string(), - rate_limit_config: None, - threat_response: Some(threat_response.clone()), - }; - - // Test create_remediation_details with threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ); - - assert!(remediation.is_some()); - let remediation = remediation.unwrap(); - - // Verify WAF data - assert_eq!(remediation.waf_action, Some("block".to_string())); - assert_eq!(remediation.waf_rule_id, Some("aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string())); - assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Block".to_string())); - - // Verify threat intelligence data - assert_eq!(remediation.threat_score, Some(85)); - assert_eq!(remediation.threat_confidence, Some(0.95)); - assert_eq!(remediation.threat_categories, Some(vec!["malware".to_string(), 
"botnet".to_string()])); - assert_eq!(remediation.threat_tags, Some(vec!["suspicious".to_string()])); - assert_eq!(remediation.threat_reason_code, Some("THREAT_DETECTED".to_string())); - assert_eq!(remediation.threat_reason_summary, Some("IP address associated with malicious activity".to_string())); - assert_eq!(remediation.threat_advice, Some("Block this IP address".to_string())); - assert_eq!(remediation.ip_country, Some("US".to_string())); - assert_eq!(remediation.ip_asn, Some(12345)); - assert_eq!(remediation.ip_asn_org, Some("Test ISP".to_string())); - assert_eq!(remediation.ip_asn_country, Some("US".to_string())); - } - - #[test] - fn test_remediation_with_waf_challenge_and_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; - - // Create a mock threat response - let threat_response = ThreatResponse { - schema_version: "1.0.0".to_string(), - tenant_id: "test-tenant".to_string(), - ip: "10.0.0.1".to_string(), - intel: ThreatIntel { - score: 60, - confidence: 0.75, - score_version: "1.0".to_string(), - categories: vec!["suspicious".to_string()], - tags: vec!["review".to_string()], - first_seen: Some(Utc::now()), - last_seen: Some(Utc::now()), - source_count: 2, - reason_code: "SUSPICIOUS_ACTIVITY".to_string(), - reason_summary: "Unusual traffic patterns detected".to_string(), - rule_id: "rule-456".to_string(), - }, - context: ThreatContext { - asn: 67890, - org: "Another ISP".to_string(), - ip_version: 4, - geo: GeoInfo { - country: "Canada".to_string(), - iso_code: "CA".to_string(), - asn_iso_code: "CA".to_string(), - }, - }, - advice: "Challenge with CAPTCHA".to_string(), - ttl_s: 1800, - generated_at: Utc::now(), - }; - - // Create a WAF result with challenge action and threat intelligence - let waf_result = WafResult { - action: WafAction::Challenge, - rule_name: "Threat intelligence - Challenge".to_string(), - rule_id: "1eb12716-6e13-4e23-a1d9-c879f6175317".to_string(), - rate_limit_config: None, - threat_response: Some(threat_response.clone()), - }; - - // Test create_remediation_details with challenge action and threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ); - - assert!(remediation.is_some()); - let remediation = remediation.unwrap(); - - // Verify WAF data (challenge should be included in remediation) - assert_eq!(remediation.waf_action, Some("challenge".to_string())); - assert_eq!(remediation.waf_rule_id, Some("1eb12716-6e13-4e23-a1d9-c879f6175317".to_string())); - assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Challenge".to_string())); - - // Verify threat intelligence data - assert_eq!(remediation.threat_score, Some(60)); - assert_eq!(remediation.threat_confidence, Some(0.75)); - assert_eq!(remediation.threat_categories, Some(vec!["suspicious".to_string()])); - assert_eq!(remediation.ip_country, Some("CA".to_string())); - assert_eq!(remediation.ip_asn, Some(67890)); - } - - #[test] - fn test_remediation_without_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - - // Create a WAF result without threat intelligence - let waf_result = WafResult { - action: WafAction::Block, - rule_name: "Custom Rule".to_string(), - rule_id: "custom-rule-123".to_string(), - rate_limit_config: None, - threat_response: None, - }; - - // Test create_remediation_details without threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - 
Some(&waf_result), - None, - ); - - assert!(remediation.is_some()); - let remediation = remediation.unwrap(); - - // Verify WAF data is present - assert_eq!(remediation.waf_action, Some("block".to_string())); - assert_eq!(remediation.waf_rule_id, Some("custom-rule-123".to_string())); - - // Verify threat intelligence data is null - assert_eq!(remediation.threat_score, None); - assert_eq!(remediation.threat_confidence, None); - assert_eq!(remediation.threat_categories, None); - assert_eq!(remediation.ip_country, None); - assert_eq!(remediation.ip_asn, None); - } - - #[test] - fn test_remediation_json_serialization_with_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; - - // Create a mock threat response - let threat_response = ThreatResponse { - schema_version: "1.0.0".to_string(), - tenant_id: "test-tenant".to_string(), - ip: "192.168.1.100".to_string(), - intel: ThreatIntel { - score: 90, - confidence: 0.98, - score_version: "1.0".to_string(), - categories: vec!["malware".to_string()], - tags: vec!["critical".to_string()], - first_seen: Some(Utc::now()), - last_seen: Some(Utc::now()), - source_count: 10, - reason_code: "MALWARE_DETECTED".to_string(), - reason_summary: "Known malware source".to_string(), - rule_id: "rule-789".to_string(), - }, - context: ThreatContext { - asn: 99999, - org: "Malicious Network".to_string(), - ip_version: 4, - geo: GeoInfo { - country: "Unknown".to_string(), - iso_code: "XX".to_string(), - asn_iso_code: "XX".to_string(), - }, - }, - advice: "Immediate block required".to_string(), - ttl_s: 7200, - generated_at: Utc::now(), - }; - - let waf_result = WafResult { - action: WafAction::Block, - rule_name: "Threat intelligence - Block".to_string(), - rule_id: "test-rule-id".to_string(), - rate_limit_config: None, - threat_response: Some(threat_response.clone()), - }; - - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ).unwrap(); - - // Create a full access log with remediation - let access_log = HttpAccessLog { - event_type: "http_access_log".to_string(), - schema_version: "1.0.0".to_string(), - timestamp: Utc::now(), - request_id: "test_req_123".to_string(), - http: HttpDetails { - method: "GET".to_string(), - scheme: "https".to_string(), - host: "example.com".to_string(), - port: 443, - path: "/".to_string(), - query: "".to_string(), - query_hash: None, - headers: HashMap::new(), - ja4h: None, - user_agent: None, - content_type: None, - content_length: None, - body: "".to_string(), - body_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), - body_truncated: false, - }, - network: NetworkDetails { - src_ip: "192.168.1.100".to_string(), - src_port: 12345, - dst_ip: "10.0.0.1".to_string(), - dst_port: 443, - }, - tls: None, - response: ResponseDetails { - status: 403, - status_text: "Forbidden".to_string(), - content_type: None, - content_length: None, - body: "".to_string(), - }, - remediation: Some(remediation), - upstream: None, - performance: None, - }; - - // Serialize to JSON and verify threat intelligence fields are present - let json = access_log.to_json().unwrap(); - assert!(json.contains("\"threat_score\":90")); - assert!(json.contains("\"threat_confidence\":0.98")); - assert!(json.contains("\"threat_categories\":[\"malware\"]")); - assert!(json.contains("\"ip_country\":\"XX\"")); - assert!(json.contains("\"ip_asn\":99999")); - 
assert!(json.contains("\"threat_reason_code\":\"MALWARE_DETECTED\"")); - assert!(json.contains("\"threat_reason_summary\":\"Known malware source\"")); - } -} +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::{SystemTime, UNIX_EPOCH}; + +use chrono::{DateTime, Utc}; +use hyper::{Response, header::HeaderValue}; +use http_body_util::{BodyExt, Full}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; + +use crate::ja4_plus::{Ja4hFingerprint, Ja4tFingerprint}; +use crate::utils::tcp_fingerprint::TcpFingerprintData; +use crate::worker::log::{get_log_sender_config, send_event, UnifiedEvent}; + +// Re-export for compatibility +pub use crate::worker::log::LogSenderConfig; + +/// Server certificate information for access logging +#[derive(Debug, Clone)] +pub struct ServerCertInfo { + pub issuer: String, + pub subject: String, + pub not_before: String, // RFC3339 format + pub not_after: String, // RFC3339 format + pub fingerprint_sha256: String, +} + +/// Lightweight access log summary for returning with responses +/// +/// # Usage Example +/// +/// ```no_run +/// use synapse::access_log::{AccessLogSummary, UpstreamInfo, PerformanceInfo}; +/// use chrono::Utc; +/// +/// // Create a summary with upstream and performance info +/// let summary = AccessLogSummary { +/// request_id: "req_123".to_string(), +/// timestamp: Utc::now(), +/// upstream: Some(UpstreamInfo { +/// selected: "backend1.example.com".to_string(), +/// method: "round_robin".to_string(), +/// reason: "healthy".to_string(), +/// }), +/// waf: None, +/// threat: None, +/// network: synapse::access_log::NetworkSummary { +/// src_ip: "1.2.3.4".to_string(), +/// dst_ip: "10.0.0.1".to_string(), +/// protocol: "https".to_string(), +/// }, +/// performance: PerformanceInfo { +/// request_time_ms: Some(150), +/// upstream_time_ms: Some(120), +/// }, +/// }; +/// +/// // Add to response headers +/// // summary.add_to_response_headers(&mut response); +/// +/// // Or get as JSON +/// let json = summary.to_json().unwrap(); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessLogSummary { + pub request_id: String, + pub timestamp: DateTime, + pub upstream: Option, + pub waf: Option, + pub threat: Option, + pub network: NetworkSummary, + pub performance: PerformanceInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamInfo { + pub selected: String, + pub method: String, + pub reason: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WafInfo { + pub action: String, + pub rule_id: String, + pub rule_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreatInfo { + pub score: u32, + pub confidence: f64, + pub categories: Vec, + pub reason: String, + pub country: Option, + pub asn: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NetworkSummary { + pub src_ip: String, + pub dst_ip: String, + pub protocol: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PerformanceInfo { + pub request_time_ms: Option, + pub upstream_time_ms: Option, +} + +impl AccessLogSummary { + /// Convert to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + /// Convert to compact JSON for headers (excludes null fields) + pub fn to_compact_json(&self) -> String { + let mut parts = vec![format!(r#""request_id":"{}""#, self.request_id)]; + + if let Some(upstream) = &self.upstream { + parts.push(format!(r#""upstream":"{}""#, upstream.selected)); + 
parts.push(format!(r#""upstream_method":"{}""#, upstream.method)); + } + + if let Some(waf) = &self.waf { + parts.push(format!(r#""waf_action":"{}""#, waf.action)); + parts.push(format!(r#""waf_rule":"{}""#, waf.rule_name)); + } + + if let Some(threat) = &self.threat { + parts.push(format!(r#""threat_score":{}"#, threat.score)); + parts.push(format!(r#""threat_confidence":{:.2}"#, threat.confidence)); + } + + if let Some(ms) = self.performance.request_time_ms { + parts.push(format!(r#""request_time_ms":{}"#, ms)); + } + + format!("{{{}}}", parts.join(",")) + } + + /// Add as response headers + pub fn add_to_response_headers(&self, response: &mut Response>) { + let headers = response.headers_mut(); + + // Add request ID header + if let Ok(value) = HeaderValue::from_str(&self.request_id) { + headers.insert("X-Request-ID", value); + } + + // Add upstream info + if let Some(upstream) = &self.upstream { + if let Ok(value) = HeaderValue::from_str(&upstream.selected) { + headers.insert("X-Upstream-Server", value); + } + if let Ok(value) = HeaderValue::from_str(&upstream.method) { + headers.insert("X-Upstream-Method", value); + } + } + + // Add WAF info + if let Some(waf) = &self.waf { + if let Ok(value) = HeaderValue::from_str(&waf.action) { + headers.insert("X-WAF-Action", value); + } + if let Ok(value) = HeaderValue::from_str(&waf.rule_id) { + headers.insert("X-WAF-Rule-ID", value); + } + } + + // Add threat info + if let Some(threat) = &self.threat { + if let Ok(value) = HeaderValue::from_str(&threat.score.to_string()) { + headers.insert("X-Threat-Score", value); + } + if let Some(country) = &threat.country { + if let Ok(value) = HeaderValue::from_str(country) { + headers.insert("X-Client-Country", value); + } + } + } + + // Add performance metrics + if let Some(ms) = self.performance.request_time_ms { + if let Ok(value) = HeaderValue::from_str(&ms.to_string()) { + headers.insert("X-Request-Time-Ms", value); + } + } + + if let Some(ms) = self.performance.upstream_time_ms { + if let Ok(value) = HeaderValue::from_str(&ms.to_string()) { + headers.insert("X-Upstream-Time-Ms", value); + } + } + + // Add compact JSON summary + let compact = self.to_compact_json(); + if let Ok(value) = HeaderValue::from_str(&compact) { + headers.insert("X-Access-Log", value); + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HttpAccessLog { + pub event_type: String, + pub schema_version: String, + pub timestamp: DateTime, + pub request_id: String, + pub http: HttpDetails, + pub network: NetworkDetails, + pub tls: Option, + pub response: ResponseDetails, + pub remediation: Option, + pub upstream: Option, + pub performance: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HttpDetails { + pub method: String, + pub scheme: String, + pub host: String, + pub port: u16, + pub path: String, + pub query: String, + pub query_hash: Option, + pub headers: HashMap, + pub ja4h: Option, + pub user_agent: Option, + pub content_type: Option, + pub content_length: Option, + pub body: String, + pub body_sha256: String, + pub body_truncated: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NetworkDetails { + pub src_ip: String, + pub src_port: u16, + pub dst_ip: String, + pub dst_port: u16, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsDetails { + pub version: String, + pub cipher: String, + pub alpn: Option, + pub sni: Option, + pub ja4: Option, + pub ja4_unsorted: Option, + pub ja4t: Option, + pub server_cert: Option, +} + +#[derive(Debug, 
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ServerCertDetails {
+    pub issuer: String,
+    pub subject: String,
+    pub not_before: DateTime<Utc>,
+    pub not_after: DateTime<Utc>,
+    pub fingerprint_sha256: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseDetails {
+    pub status: u16,
+    pub status_text: String,
+    pub content_type: Option<String>,
+    pub content_length: Option<u64>,
+    pub body: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RemediationDetails {
+    pub waf_action: Option<String>,
+    pub waf_rule_id: Option<String>,
+    pub waf_rule_name: Option<String>,
+    pub threat_score: Option<u32>,
+    pub threat_confidence: Option<f64>,
+    pub threat_categories: Option<Vec<String>>,
+    pub threat_tags: Option<Vec<String>>,
+    pub threat_reason_code: Option<String>,
+    pub threat_reason_summary: Option<String>,
+    pub threat_advice: Option<String>,
+    pub ip_country: Option<String>,
+    pub ip_asn: Option<u32>,
+    pub ip_asn_org: Option<String>,
+    pub ip_asn_country: Option<String>,
+}
+
+impl HttpAccessLog {
+    /// Create access log from request parts and response data
+    pub async fn create_from_parts(
+        req_parts: &hyper::http::request::Parts,
+        req_body_bytes: &bytes::Bytes,
+        peer_addr: SocketAddr,
+        dst_addr: SocketAddr,
+        tls_fingerprint: Option<&crate::ja4_plus::Ja4hFingerprint>,
+        tcp_fingerprint_data: Option<&TcpFingerprintData>,
+        server_cert_info: Option<&ServerCertInfo>,
+        response_data: ResponseData,
+        waf_result: Option<&crate::waf::wirefilter::WafResult>,
+        threat_data: Option<&crate::threat::ThreatResponse>,
+        upstream_info: Option<UpstreamInfo>,
+        performance_info: Option<PerformanceInfo>,
+        tls_sni: Option<String>,
+        tls_alpn: Option<String>,
+        tls_cipher: Option<String>,
+        tls_ja4: Option<String>,
+        tls_ja4_unsorted: Option<String>,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let timestamp = Utc::now();
+        let request_id = format!("req_{}", SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos());
+
+        // Extract request details
+        let uri = &req_parts.uri;
+        let method = req_parts.method.to_string();
+
+        // Determine scheme: prefer URI scheme, fallback to TLS fingerprint presence, then default to http
+        let scheme = uri.scheme().map(|s| s.to_string()).unwrap_or_else(|| {
+            if tls_fingerprint.is_some() {
+                "https".to_string()
+            } else {
+                "http".to_string()
+            }
+        });
+
+        // Extract host from URI, fallback to Host header if URI doesn't have host
+        let host = uri.host().map(|h| h.to_string()).unwrap_or_else(|| {
+            req_parts.headers
+                .get("host")
+                .and_then(|h| h.to_str().ok())
+                .map(|h| h.split(':').next().unwrap_or(h).to_string())
+                .unwrap_or_else(|| "unknown".to_string())
+        });
+
+        // Determine port: prefer URI port, fallback to scheme-based default
+        let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 });
+        let path = uri.path().to_string();
+        let query = uri.query().unwrap_or("").to_string();
+
+        // Process headers
+        let mut headers = HashMap::new();
+        let mut user_agent = None;
+        let mut content_type = None;
+
+        for (name, value) in req_parts.headers.iter() {
+            let key = name.to_string();
+            let val = value.to_str().unwrap_or("").to_string();
+            headers.insert(key, val.clone());
+
+            if name.as_str().to_lowercase() == "user-agent" {
+                user_agent = Some(val.clone());
+            }
+            if name.as_str().to_lowercase() == "content-type" {
+                content_type = Some(val);
+            }
+        }
+
+        // Generate JA4H fingerprint
+        let ja4h_fp = Ja4hFingerprint::from_http_request(
+            req_parts.method.as_str(),
+            &format!("{:?}", req_parts.version),
+            &req_parts.headers
+        );
+
+        // Get log sender configuration for body processing
+        let log_config = {
+            let config_store = get_log_sender_config();
+            let config_guard = config_store.read().unwrap();
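+            // Note: the hex constant used in the body handling below is the
+            // SHA-256 digest of an empty input; it is reported both when the
+            // body is genuinely empty and when request-body logging is disabled.
+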
config_guard.as_ref().cloned() + }; + + // Process request body with truncation - respect include_request_body setting + let (body_str, body_sha256, body_truncated) = if let Some(config) = &log_config { + if !config.include_request_body { + // Request body logging disabled + ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) + } else { + let max_body_size = config.max_body_size; + let truncated = req_body_bytes.len() > max_body_size; + let truncated_body_bytes = if truncated { + req_body_bytes.slice(..max_body_size) + } else { + req_body_bytes.clone() + }; + let body = String::from_utf8_lossy(&truncated_body_bytes).to_string(); + + // Calculate SHA256 hash - handle empty body explicitly + let hash = if req_body_bytes.is_empty() { + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string() + } else { + format!("{:x}", Sha256::digest(req_body_bytes)) + }; + + (body, hash, truncated) + } + } else { + // No config, default to disabled + ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) + }; + + // Generate JA4T from TCP fingerprint data if available + let ja4t = tcp_fingerprint_data.map(|tcp_data| { + let ja4t_fp = Ja4tFingerprint::from_tcp_data( + tcp_data.window_size, + tcp_data.ttl, + tcp_data.mss, + tcp_data.window_scale, + &tcp_data.options, + ); + ja4t_fp.fingerprint + }); + + // Process TLS details + let tls_details = if let Some(fp) = tls_fingerprint { + // Use actual TLS version from fingerprint if available, otherwise infer from HTTP version + let tls_version = if scheme == "https" { + // Check if version looks like TLS version (e.g., "TLS 1.2", "TLS 1.3") + if fp.version.starts_with("TLS") { + fp.version.clone() + } else { + // Otherwise infer from HTTP version + match fp.version.as_str() { + "2.0" | "2" => "TLS 1.2".to_string(), // HTTP/2 typically uses TLS 1.2+ + "3.0" | "3" => "TLS 1.3".to_string(), // HTTP/3 uses TLS 1.3 + _ => "TLS 1.2".to_string(), // Default for HTTPS + } + } + } else { + "".to_string() // No TLS for HTTP + }; + + // Determine cipher - use provided cipher or infer from TLS version + let cipher = if let Some(ref provided_cipher) = tls_cipher { + provided_cipher.clone() + } else if scheme == "https" { + match fp.version.as_str() { + "3.0" | "3" => "TLS_AES_256_GCM_SHA384".to_string(), // HTTP/3 uses TLS 1.3 + "2.0" | "2" => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // HTTP/2 typically uses TLS 1.2 + _ => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // Default TLS 1.2 cipher + } + } else { + "".to_string() // No cipher for HTTP + }; + + // Extract server certificate details if available + let server_cert = extract_server_cert_details(server_cert_info); + + // Extract JA4 from TLS fingerprint data - use tls_ja4 if available + let ja4_value = tls_ja4.clone(); + + Some(TlsDetails { + version: tls_version, + cipher, + alpn: tls_alpn.clone(), + sni: tls_sni.clone(), + ja4: ja4_value, + ja4_unsorted: tls_ja4_unsorted.clone(), + ja4t: ja4t.clone(), + server_cert, + }) + } else if scheme == "https" { + // Create minimal TLS details for HTTPS connections without fingerprint (e.g., PROXY protocol) + let server_cert = extract_server_cert_details(server_cert_info); + + Some(TlsDetails { + version: "TLS 1.3".to_string(), + cipher: "TLS_AES_256_GCM_SHA384".to_string(), + alpn: None, + sni: None, + ja4: Some("t13d".to_string()), + ja4_unsorted: Some("t13d".to_string()), + ja4t: ja4t.clone(), + server_cert, + }) + } else { + None + 
};

+        // Create HTTP details
+        let http_details = HttpDetails {
+            method,
+            scheme,
+            host,
+            port,
+            path,
+            query: query.clone(),
+            query_hash: if query.is_empty() { None } else { Some(format!("{:x}", Sha256::digest(query.as_bytes()))) },
+            headers,
+            ja4h: Some(ja4h_fp.fingerprint.clone()),
+            user_agent,
+            content_type,
+            content_length: Some(req_body_bytes.len() as u64),
+            body: body_str,
+            body_sha256,
+            body_truncated,
+        };
+
+        // Create network details
+        let network_details = NetworkDetails {
+            src_ip: peer_addr.ip().to_string(),
+            src_port: peer_addr.port(),
+            dst_ip: dst_addr.ip().to_string(),
+            dst_port: dst_addr.port(),
+        };
+
+        // Create response details from response_data - response body logging disabled
+        let response_details = ResponseDetails {
+            status: response_data.response_json["status"].as_u64().unwrap_or(0) as u16,
+            status_text: response_data.response_json["status_text"].as_str().unwrap_or("Unknown").to_string(),
+            content_type: response_data.response_json["content_type"].as_str().map(|s| s.to_string()),
+            content_length: response_data.response_json["content_length"].as_u64(),
+            body: "".to_string(),
+        };
+
+        // Create remediation details
+        let remediation_details = Self::create_remediation_details(waf_result, threat_data);
+
+        // Create the access log
+        let access_log = HttpAccessLog {
+            event_type: "http_access_log".to_string(),
+            schema_version: "1.0.0".to_string(),
+            timestamp,
+            request_id,
+            http: http_details,
+            network: network_details,
+            tls: tls_details,
+            response: response_details,
+            remediation: remediation_details,
+            upstream: upstream_info,
+            performance: performance_info,
+        };
+
+        // Log to stdout (existing behavior)
+        if let Err(e) = access_log.log_to_stdout() {
+            log::warn!("Failed to log access log to stdout: {}", e);
+        }
+
+        // Send to unified event queue
+        send_event(UnifiedEvent::HttpAccessLog(access_log));
+
+        Ok(())
+    }
+
+    /// Create remediation details from WAF result and threat intelligence data
+    fn create_remediation_details(
+        waf_result: Option<&crate::waf::wirefilter::WafResult>,
+        threat_data: Option<&crate::threat::ThreatResponse>,
+    ) -> Option<RemediationDetails> {
+        // Check if WAF action requires remediation (Block/Challenge/RateLimit) - these will populate WAF fields
+        // RateLimit is included because it blocks requests when the limit is exceeded
+        let has_waf_remediation = match waf_result {
+            Some(waf) => matches!(
+                waf.action,
+                crate::waf::wirefilter::WafAction::Block
+                    | crate::waf::wirefilter::WafAction::Challenge
+                    | crate::waf::wirefilter::WafAction::RateLimit
+            ),
+            None => false,
+        };
+
+        // Check if threat data is meaningful (not just default/empty values)
+        let has_meaningful_threat_data = threat_data.map(|threat| {
+            threat.intel.score > 0
+                || threat.intel.reason_code != "NO_DATA"
+                || !threat.intel.categories.is_empty()
+                || !threat.intel.tags.is_empty()
+        }).unwrap_or(false);
+
+        // Create remediation section if:
+        // 1. WAF action is Block/Challenge/RateLimit (will populate WAF fields), OR
+        // 2.
There's meaningful threat intelligence data (will populate threat fields) + // Note: WAF fields (waf_action, waf_rule_id, waf_rule_name) are populated for Block/Challenge/RateLimit + // RateLimit is included because it blocks requests when the limit is exceeded + // Allow actions don't populate WAF fields, but remediation section can still exist if there's meaningful threat data + if !has_waf_remediation && !has_meaningful_threat_data { + return None; + } + + let mut remediation = RemediationDetails { + waf_action: None, + waf_rule_id: None, + waf_rule_name: None, + threat_score: None, + threat_confidence: None, + threat_categories: None, + threat_tags: None, + threat_reason_code: None, + threat_reason_summary: None, + threat_advice: None, + ip_country: None, + ip_asn: None, + ip_asn_org: None, + ip_asn_country: None, + }; + + // Populate WAF data if available for actions that require remediation (Block/Challenge/RateLimit) + // RateLimit actions also populate WAF fields because they block/challenge requests when exceeded + // Allow actions don't populate WAF fields, but remediation section may still exist if there's meaningful threat data + if let Some(waf) = waf_result { + // Include WAF data in remediation for Block, Challenge, and RateLimit actions + // RateLimit is included because it blocks requests when the limit is exceeded + match waf.action { + crate::waf::wirefilter::WafAction::Block + | crate::waf::wirefilter::WafAction::Challenge + | crate::waf::wirefilter::WafAction::RateLimit => { + remediation.waf_action = Some(format!("{:?}", waf.action).to_lowercase()); + remediation.waf_rule_id = Some(waf.rule_id.clone()); + remediation.waf_rule_name = Some(waf.rule_name.clone()); + } + crate::waf::wirefilter::WafAction::Allow => { + // Allow actions don't populate WAF fields + // But remediation section may still exist if there's meaningful threat data + } + } + } + + // Populate threat intelligence data if available and meaningful + if let Some(threat) = threat_data { + // Always populate GeoIP fields (country, ASN) when available, regardless of threat score + // These are geographic/network identifiers that should always be included + let country_code = threat.context.geo.iso_code.clone(); + remediation.ip_country = Some(country_code); + remediation.ip_asn = Some(threat.context.asn); + remediation.ip_asn_org = Some(threat.context.org.clone()); + remediation.ip_asn_country = Some(threat.context.geo.asn_iso_code.clone()); + + // Only include threat-specific data if it's meaningful (not just default/empty values) + if has_meaningful_threat_data { + remediation.threat_score = Some(threat.intel.score); + remediation.threat_confidence = Some(threat.intel.confidence); + remediation.threat_categories = Some(threat.intel.categories.clone()); + remediation.threat_tags = Some(threat.intel.tags.clone()); + remediation.threat_reason_code = Some(threat.intel.reason_code.clone()); + remediation.threat_reason_summary = Some(threat.intel.reason_summary.clone()); + remediation.threat_advice = Some(threat.advice.clone()); + } + } + + // Only return remediation if it has any meaningful data (WAF fields for Block/Challenge, meaningful threat data, or GeoIP data) + let has_waf_data = remediation.waf_action.is_some(); + let has_threat_data = remediation.threat_score.is_some() || remediation.threat_reason_code.is_some(); + let has_geoip_data = remediation.ip_country.is_some() || remediation.ip_asn.is_some(); + + if has_waf_data || has_threat_data || has_geoip_data { + Some(remediation) + } else { + None + } 
+    }
+
+    pub fn to_json(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+
+    pub fn log_to_stdout(&self) -> Result<(), Box<dyn std::error::Error>> {
+        let json = self.to_json()?;
+        log::info!("{}", json);
+        Ok(())
+    }
+
+    /// Create a lightweight summary suitable for returning with responses
+    pub fn to_summary(&self) -> AccessLogSummary {
+        let waf_info = if let Some(remediation) = &self.remediation {
+            if let (Some(action), Some(rule_id), Some(rule_name)) =
+                (&remediation.waf_action, &remediation.waf_rule_id, &remediation.waf_rule_name) {
+                Some(WafInfo {
+                    action: action.clone(),
+                    rule_id: rule_id.clone(),
+                    rule_name: rule_name.clone(),
+                })
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
+        let threat_info = if let Some(remediation) = &self.remediation {
+            if let (Some(score), Some(confidence)) =
+                (remediation.threat_score, remediation.threat_confidence) {
+                Some(ThreatInfo {
+                    score,
+                    confidence,
+                    categories: remediation.threat_categories.clone().unwrap_or_default(),
+                    reason: remediation.threat_reason_summary.clone().unwrap_or_default(),
+                    country: remediation.ip_country.clone(),
+                    asn: remediation.ip_asn,
+                })
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
+        let protocol = if self.tls.is_some() {
+            format!("{} over {}", self.http.scheme,
+                self.tls.as_ref().map(|t| t.version.as_str()).unwrap_or("TLS"))
+        } else {
+            self.http.scheme.clone()
+        };
+
+        AccessLogSummary {
+            request_id: self.request_id.clone(),
+            timestamp: self.timestamp,
+            upstream: self.upstream.clone(),
+            waf: waf_info,
+            threat: threat_info,
+            network: NetworkSummary {
+                src_ip: self.network.src_ip.clone(),
+                dst_ip: self.network.dst_ip.clone(),
+                protocol,
+            },
+            performance: self.performance.clone().unwrap_or(PerformanceInfo {
+                request_time_ms: None,
+                upstream_time_ms: None,
+            }),
+        }
+    }
+
+    /// Add upstream routing information to the access log
+    pub fn with_upstream(mut self, upstream: UpstreamInfo) -> Self {
+        self.upstream = Some(upstream);
+        self
+    }
+
+    /// Add performance metrics to the access log
+    pub fn with_performance(mut self, performance: PerformanceInfo) -> Self {
+        self.performance = Some(performance);
+        self
+    }
+}
+
+/// Helper struct to hold response data for access logging
+#[derive(Debug, Clone)]
+pub struct ResponseData {
+    pub response_json: serde_json::Value,
+    pub blocking_info: Option<serde_json::Value>,
+    pub waf_result: Option<crate::waf::wirefilter::WafResult>,
+    pub threat_data: Option<crate::threat::ThreatResponse>,
+}
+
+impl ResponseData {
+    /// Create response data for a regular response
+    pub async fn from_response(response: Response<Full<bytes::Bytes>>) -> Result<Self, Box<dyn std::error::Error>> {
+        let (response_parts, response_body) = response.into_parts();
+        let response_body_bytes = response_body.collect().await?.to_bytes();
+        let response_body_str = String::from_utf8_lossy(&response_body_bytes).to_string();
+
+        let response_content_type = response_parts.headers
+            .get("content-type")
+            .and_then(|h| h.to_str().ok())
+            .map(|s| s.to_string());
+
+        let response_json = serde_json::json!({
+            "status": response_parts.status.as_u16(),
+            "status_text": response_parts.status.canonical_reason().unwrap_or("Unknown"),
+            "content_type": response_content_type,
+            "content_length": response_body_bytes.len() as u64,
+            "body": response_body_str
+        });
+
+        Ok(ResponseData {
+            response_json,
+            blocking_info: None,
+            waf_result: None,
+            threat_data: None,
+        })
+    }
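+
+    // Hypothetical call for the blocked path (the reason string and the bare
+    // None arguments here are illustrative, not values used elsewhere in this
+    // file):
+    //
+    //     let data = ResponseData::for_blocked_request("waf_rule_matched", 403, None, None);
+    //     assert!(data.blocking_info.is_some());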
=> "Upgrade Required", + 429 => "Too Many Requests", + _ => "Blocked" + }; + + let response_json = serde_json::json!({ + "status": status_code, + "status_text": status_text, + "content_type": "application/json", + "content_length": 0, + "body": format!("{{\"ok\":false,\"error\":\"{}\"}}", block_reason) + }); + + let blocking_info = serde_json::json!({ + "blocked": true, + "reason": block_reason, + "filter_type": "waf" + }); + + ResponseData { + response_json, + blocking_info: Some(blocking_info), + waf_result, + threat_data: threat_data.cloned(), + } + } + + /// Create response data for a malware-blocked request with scan details + pub fn for_malware_blocked_request( + signature: Option, + scan_error: Option, + waf_result: Option, + threat_data: Option<&crate::threat::ThreatResponse>, + ) -> Self { + let response_json = serde_json::json!({ + "status": 403, + "status_text": "Forbidden", + "content_type": "application/json", + "content_length": 0, + "body": "{\"ok\":false,\"error\":\"malware_detected\"}" + }); + + let mut blocking_info = serde_json::json!({ + "blocked": true, + "reason": "malware_detected", + "filter_type": "content_scanning", + "malware_detected": true, + }); + + if let Some(sig) = signature { + blocking_info["malware_signature"] = serde_json::Value::String(sig); + } + + if let Some(err) = scan_error { + blocking_info["scan_error"] = serde_json::Value::String(err); + } + + ResponseData { + response_json, + blocking_info: Some(blocking_info), + waf_result, + threat_data: threat_data.cloned(), + } + } +} + + +/// Extract server certificate details from server certificate info +fn extract_server_cert_details(server_cert_info: Option<&ServerCertInfo>) -> Option { + server_cert_info.map(|cert_info| { + // Parse the date strings from ServerCertInfo + let not_before = chrono::DateTime::parse_from_rfc3339(&cert_info.not_before) + .unwrap_or_else(|_| Utc::now().into()) + .with_timezone(&Utc); + let not_after = chrono::DateTime::parse_from_rfc3339(&cert_info.not_after) + .unwrap_or_else(|_| Utc::now().into()) + .with_timezone(&Utc); + + ServerCertDetails { + issuer: cert_info.issuer.clone(), + subject: cert_info.subject.clone(), + not_before, + not_after, + fingerprint_sha256: cert_info.fingerprint_sha256.clone(), + } + }) +} + + +#[cfg(test)] +mod tests { + use super::*; + use hyper::Request; + + #[tokio::test] + async fn test_access_log_creation() { + // Create a simple request + let _req = Request::builder() + .method("GET") + .uri("https://example.com/test?param=value") + .header("User-Agent", format!("TestAgent/{}", env!("CARGO_PKG_VERSION"))) + .body(Full::new(bytes::Bytes::new())) + .unwrap(); + + // Create a simple response + let _response = Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(Full::new(bytes::Bytes::from("{\"ok\":true}"))) + .unwrap(); + + let _peer: SocketAddr = "127.0.0.1:12345".parse().unwrap(); + let _dst_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + // This test would need more setup to work properly + // For now, just test the structure creation + let log = HttpAccessLog { + event_type: "http_access_log".to_string(), + schema_version: "1.0.0".to_string(), + timestamp: Utc::now(), + request_id: "test_123".to_string(), + http: HttpDetails { + method: "GET".to_string(), + scheme: "https".to_string(), + host: "example.com".to_string(), + port: 443, + path: "/test".to_string(), + query: "param=value".to_string(), + query_hash: Some("abc123".to_string()), + headers: HashMap::new(), + ja4h: 
Some("g11n_000000000000_000000000000".to_string()), + user_agent: Some(format!("TestAgent/{}", env!("CARGO_PKG_VERSION"))), + content_type: None, + content_length: None, + body: "".to_string(), + body_sha256: "abc123".to_string(), + body_truncated: false, + }, + network: NetworkDetails { + src_ip: "127.0.0.1".to_string(), + src_port: 12345, + dst_ip: "127.0.0.1".to_string(), + dst_port: 443, + }, + tls: None, + response: ResponseDetails { + status: 200, + status_text: "OK".to_string(), + content_type: Some("application/json".to_string()), + content_length: Some(10), + body: "{\"ok\":true}".to_string(), + }, + remediation: None, + upstream: Some(UpstreamInfo { + selected: "backend1".to_string(), + method: "round_robin".to_string(), + reason: "healthy".to_string(), + }), + performance: Some(PerformanceInfo { + request_time_ms: Some(50), + upstream_time_ms: Some(45), + }), + }; + + let json = log.to_json().unwrap(); + assert!(json.contains("http_access_log")); + assert!(json.contains("GET")); + assert!(json.contains("example.com")); + assert!(json.contains("backend1")); + + // Test summary creation + let summary = log.to_summary(); + assert_eq!(summary.request_id, "test_123"); + assert_eq!(summary.upstream.as_ref().unwrap().selected, "backend1"); + assert_eq!(summary.performance.request_time_ms, Some(50)); + } + + #[test] + fn test_remediation_with_threat_intelligence() { + use crate::waf::wirefilter::{WafAction, WafResult}; + use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + + // Create a mock threat response + let threat_response = ThreatResponse { + schema_version: "1.0.0".to_string(), + tenant_id: "test-tenant".to_string(), + ip: "192.168.1.100".to_string(), + intel: ThreatIntel { + score: 85, + confidence: 0.95, + score_version: "1.0".to_string(), + categories: vec!["malware".to_string(), "botnet".to_string()], + tags: vec!["suspicious".to_string()], + first_seen: Some(Utc::now()), + last_seen: Some(Utc::now()), + source_count: 5, + reason_code: "THREAT_DETECTED".to_string(), + reason_summary: "IP address associated with malicious activity".to_string(), + rule_id: "rule-123".to_string(), + }, + context: ThreatContext { + asn: 12345, + org: "Test ISP".to_string(), + ip_version: 4, + geo: GeoInfo { + country: "United States".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }, + }, + advice: "Block this IP address".to_string(), + ttl_s: 3600, + generated_at: Utc::now(), + }; + + // Create a WAF result with threat intelligence + let waf_result = WafResult { + action: WafAction::Block, + rule_name: "Threat intelligence - Block".to_string(), + rule_id: "aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string(), + rate_limit_config: None, + threat_response: Some(threat_response.clone()), + }; + + // Test create_remediation_details with threat intelligence + let remediation = HttpAccessLog::create_remediation_details( + Some(&waf_result), + Some(&threat_response), + ); + + assert!(remediation.is_some()); + let remediation = remediation.unwrap(); + + // Verify WAF data + assert_eq!(remediation.waf_action, Some("block".to_string())); + assert_eq!(remediation.waf_rule_id, Some("aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string())); + assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Block".to_string())); + + // Verify threat intelligence data + assert_eq!(remediation.threat_score, Some(85)); + assert_eq!(remediation.threat_confidence, Some(0.95)); + assert_eq!(remediation.threat_categories, Some(vec!["malware".to_string(), 
"botnet".to_string()])); + assert_eq!(remediation.threat_tags, Some(vec!["suspicious".to_string()])); + assert_eq!(remediation.threat_reason_code, Some("THREAT_DETECTED".to_string())); + assert_eq!(remediation.threat_reason_summary, Some("IP address associated with malicious activity".to_string())); + assert_eq!(remediation.threat_advice, Some("Block this IP address".to_string())); + assert_eq!(remediation.ip_country, Some("US".to_string())); + assert_eq!(remediation.ip_asn, Some(12345)); + assert_eq!(remediation.ip_asn_org, Some("Test ISP".to_string())); + assert_eq!(remediation.ip_asn_country, Some("US".to_string())); + } + + #[test] + fn test_remediation_with_waf_challenge_and_threat_intelligence() { + use crate::waf::wirefilter::{WafAction, WafResult}; + use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + + // Create a mock threat response + let threat_response = ThreatResponse { + schema_version: "1.0.0".to_string(), + tenant_id: "test-tenant".to_string(), + ip: "10.0.0.1".to_string(), + intel: ThreatIntel { + score: 60, + confidence: 0.75, + score_version: "1.0".to_string(), + categories: vec!["suspicious".to_string()], + tags: vec!["review".to_string()], + first_seen: Some(Utc::now()), + last_seen: Some(Utc::now()), + source_count: 2, + reason_code: "SUSPICIOUS_ACTIVITY".to_string(), + reason_summary: "Unusual traffic patterns detected".to_string(), + rule_id: "rule-456".to_string(), + }, + context: ThreatContext { + asn: 67890, + org: "Another ISP".to_string(), + ip_version: 4, + geo: GeoInfo { + country: "Canada".to_string(), + iso_code: "CA".to_string(), + asn_iso_code: "CA".to_string(), + }, + }, + advice: "Challenge with CAPTCHA".to_string(), + ttl_s: 1800, + generated_at: Utc::now(), + }; + + // Create a WAF result with challenge action and threat intelligence + let waf_result = WafResult { + action: WafAction::Challenge, + rule_name: "Threat intelligence - Challenge".to_string(), + rule_id: "1eb12716-6e13-4e23-a1d9-c879f6175317".to_string(), + rate_limit_config: None, + threat_response: Some(threat_response.clone()), + }; + + // Test create_remediation_details with challenge action and threat intelligence + let remediation = HttpAccessLog::create_remediation_details( + Some(&waf_result), + Some(&threat_response), + ); + + assert!(remediation.is_some()); + let remediation = remediation.unwrap(); + + // Verify WAF data (challenge should be included in remediation) + assert_eq!(remediation.waf_action, Some("challenge".to_string())); + assert_eq!(remediation.waf_rule_id, Some("1eb12716-6e13-4e23-a1d9-c879f6175317".to_string())); + assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Challenge".to_string())); + + // Verify threat intelligence data + assert_eq!(remediation.threat_score, Some(60)); + assert_eq!(remediation.threat_confidence, Some(0.75)); + assert_eq!(remediation.threat_categories, Some(vec!["suspicious".to_string()])); + assert_eq!(remediation.ip_country, Some("CA".to_string())); + assert_eq!(remediation.ip_asn, Some(67890)); + } + + #[test] + fn test_remediation_without_threat_intelligence() { + use crate::waf::wirefilter::{WafAction, WafResult}; + + // Create a WAF result without threat intelligence + let waf_result = WafResult { + action: WafAction::Block, + rule_name: "Custom Rule".to_string(), + rule_id: "custom-rule-123".to_string(), + rate_limit_config: None, + threat_response: None, + }; + + // Test create_remediation_details without threat intelligence + let remediation = HttpAccessLog::create_remediation_details( + 
Some(&waf_result), + None, + ); + + assert!(remediation.is_some()); + let remediation = remediation.unwrap(); + + // Verify WAF data is present + assert_eq!(remediation.waf_action, Some("block".to_string())); + assert_eq!(remediation.waf_rule_id, Some("custom-rule-123".to_string())); + + // Verify threat intelligence data is null + assert_eq!(remediation.threat_score, None); + assert_eq!(remediation.threat_confidence, None); + assert_eq!(remediation.threat_categories, None); + assert_eq!(remediation.ip_country, None); + assert_eq!(remediation.ip_asn, None); + } + + #[test] + fn test_remediation_json_serialization_with_threat_intelligence() { + use crate::waf::wirefilter::{WafAction, WafResult}; + use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + + // Create a mock threat response + let threat_response = ThreatResponse { + schema_version: "1.0.0".to_string(), + tenant_id: "test-tenant".to_string(), + ip: "192.168.1.100".to_string(), + intel: ThreatIntel { + score: 90, + confidence: 0.98, + score_version: "1.0".to_string(), + categories: vec!["malware".to_string()], + tags: vec!["critical".to_string()], + first_seen: Some(Utc::now()), + last_seen: Some(Utc::now()), + source_count: 10, + reason_code: "MALWARE_DETECTED".to_string(), + reason_summary: "Known malware source".to_string(), + rule_id: "rule-789".to_string(), + }, + context: ThreatContext { + asn: 99999, + org: "Malicious Network".to_string(), + ip_version: 4, + geo: GeoInfo { + country: "Unknown".to_string(), + iso_code: "XX".to_string(), + asn_iso_code: "XX".to_string(), + }, + }, + advice: "Immediate block required".to_string(), + ttl_s: 7200, + generated_at: Utc::now(), + }; + + let waf_result = WafResult { + action: WafAction::Block, + rule_name: "Threat intelligence - Block".to_string(), + rule_id: "test-rule-id".to_string(), + rate_limit_config: None, + threat_response: Some(threat_response.clone()), + }; + + let remediation = HttpAccessLog::create_remediation_details( + Some(&waf_result), + Some(&threat_response), + ).unwrap(); + + // Create a full access log with remediation + let access_log = HttpAccessLog { + event_type: "http_access_log".to_string(), + schema_version: "1.0.0".to_string(), + timestamp: Utc::now(), + request_id: "test_req_123".to_string(), + http: HttpDetails { + method: "GET".to_string(), + scheme: "https".to_string(), + host: "example.com".to_string(), + port: 443, + path: "/".to_string(), + query: "".to_string(), + query_hash: None, + headers: HashMap::new(), + ja4h: None, + user_agent: None, + content_type: None, + content_length: None, + body: "".to_string(), + body_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), + body_truncated: false, + }, + network: NetworkDetails { + src_ip: "192.168.1.100".to_string(), + src_port: 12345, + dst_ip: "10.0.0.1".to_string(), + dst_port: 443, + }, + tls: None, + response: ResponseDetails { + status: 403, + status_text: "Forbidden".to_string(), + content_type: None, + content_length: None, + body: "".to_string(), + }, + remediation: Some(remediation), + upstream: None, + performance: None, + }; + + // Serialize to JSON and verify threat intelligence fields are present + let json = access_log.to_json().unwrap(); + assert!(json.contains("\"threat_score\":90")); + assert!(json.contains("\"threat_confidence\":0.98")); + assert!(json.contains("\"threat_categories\":[\"malware\"]")); + assert!(json.contains("\"ip_country\":\"XX\"")); + assert!(json.contains("\"ip_asn\":99999")); + 
assert!(json.contains("\"threat_reason_code\":\"MALWARE_DETECTED\"")); + assert!(json.contains("\"threat_reason_summary\":\"Known malware source\"")); + } +} diff --git a/src/access_rules.rs b/src/access_rules.rs index 22e5653..194feda 100644 --- a/src/access_rules.rs +++ b/src/access_rules.rs @@ -1,719 +1,719 @@ -use std::collections::HashSet; -use std::net::{Ipv4Addr, Ipv6Addr, IpAddr}; -use std::str::FromStr; -use std::sync::{Arc, Mutex, OnceLock}; - -use crate::bpf; -use crate::worker::config; -use crate::worker::config::global_config; -use crate::waf::wirefilter::update_http_filter_from_config_value; -use crate::firewall::{Firewall, SYNAPSEFirewall, NftablesFirewall, IptablesFirewall}; -use crate::utils::http_utils::{parse_ip_or_cidr, is_ip_in_cidr}; - -// Store previous rules state for comparison -type PreviousRules = Arc>>; -type PreviousRulesV6 = Arc>>; - -// Global previous rules state for the access rules worker -static PREVIOUS_RULES: OnceLock = OnceLock::new(); -static PREVIOUS_RULES_V6: OnceLock = OnceLock::new(); - -fn get_previous_rules() -> &'static PreviousRules { - PREVIOUS_RULES.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) -} - -fn get_previous_rules_v6() -> &'static PreviousRulesV6 { - PREVIOUS_RULES_V6.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) -} - -/// Apply access rules once using the current global config snapshot (initial setup) -pub fn init_access_rules_from_global( - skels: &Vec>>, -) -> Result<(), Box> { - if skels.is_empty() { - return Ok(()); - } - if let Ok(guard) = global_config().read() { - if let Some(cfg_arc) = guard.as_ref() { - let previous_rules: PreviousRules = Arc::new(Mutex::new(std::collections::HashSet::new())); - let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(std::collections::HashSet::new())); - // Use Arc clone instead of full Config clone for efficiency - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; - apply_rules(skels, &resp, &previous_rules, &previous_rules_v6)?; - } - } - Ok(()) -} - -/// Apply access rules and WAF rules from global config snapshot -/// This is called periodically by the ConfigWorker after it fetches new config -/// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates -pub fn apply_rules_from_global( - skels: &Vec>>, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, - skip_waf_update: bool, -) -> Result<(), Box> { - // Read from global config and apply if available - if let Ok(guard) = global_config().read() { - if let Some(cfg_arc) = guard.as_ref() { - // Update WAF wirefilter when config changes (skip in agent mode) - if !skip_waf_update { - if let Err(e) = update_http_filter_from_config_value(cfg_arc) { - log::error!("failed to update HTTP filter from config: {e}"); - } - } - if skels.is_empty() { - return Ok(()); - } - // Use Arc clone instead of full Config clone for efficiency - apply_rules( - skels, - &config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }, - previous_rules, - previous_rules_v6, - )?; - return Ok(()); - } - } - Ok(()) -} - -/// Apply access rules and WAF rules from global config using global state -/// This is called by the ConfigWorker after it fetches new config -/// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates -pub fn apply_rules_from_global_with_state( - skels: &Vec>>, - skip_waf_update: bool, -) -> Result<(), Box> { - apply_rules_from_global(skels, get_previous_rules(), get_previous_rules_v6(), skip_waf_update) -} - -/// Apply 
access rules using nftables firewall backend -/// This is used when BPF/XDP is not available -pub fn apply_rules_nftables( - nft_fw: &Arc>, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, -) -> Result<(), Box> { - if let Ok(guard) = global_config().read() { - if let Some(cfg_arc) = guard.as_ref() { - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; - apply_rules_to_nftables(nft_fw, &resp, previous_rules, previous_rules_v6)?; - } - } - Ok(()) -} - -/// Apply access rules to nftables firewall -fn apply_rules_to_nftables( - nft_fw: &Arc>, - resp: &config::ConfigApiResponse, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, -) -> Result<(), Box> { - fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - if let Some(slash) = s.find('/') { - let (ip_str, prefix_str) = s.split_at(slash); - let prefix_str = &prefix_str[1..]; - if let (Ok(ip), Ok(prefix)) = (Ipv4Addr::from_str(ip_str), prefix_str.parse::()) { - return Some((ip, prefix)); - } - } else if let Ok(ip) = Ipv4Addr::from_str(s) { - return Some((ip, 32)); - } - None - } - - fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - if let Some(slash) = s.find('/') { - let (ip_str, prefix_str) = s.split_at(slash); - let prefix_str = &prefix_str[1..]; - if let (Ok(ip), Ok(prefix)) = (Ipv6Addr::from_str(ip_str), prefix_str.parse::()) { - return Some((ip, prefix)); - } - } else if let Ok(ip) = Ipv6Addr::from_str(s) { - return Some((ip, 128)); - } - None - } - - // Build current rules sets from config - let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); - let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); - - // Process block rules - for entry in &resp.config.access_rules.block.ips { - if let Some(parsed) = parse_ipv4_ip_or_cidr(entry) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(entry) { - current_rules_v6.insert(parsed); - } - } - - // Process country-based block rules - for country_map in &resp.config.access_rules.block.country { - for (_country_code, ip_list) in country_map.iter() { - for ip_str in ip_list { - if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_str) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_str) { - current_rules_v6.insert(parsed); - } - } - } - } - - // Process ASN-based block rules - for asn_map in &resp.config.access_rules.block.asn { - for (_asn, ip_list) in asn_map.iter() { - for ip_str in ip_list { - if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_str) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_str) { - current_rules_v6.insert(parsed); - } - } - } - } - - // Lock previous rules - let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; - let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; - - // Check for changes - let ipv4_changed = current_rules != *previous_rules_guard; - let ipv6_changed = current_rules_v6 != *previous_rules_v6_guard; - - if !ipv4_changed && !ipv6_changed { - return Ok(()); - } - - // Compute diffs - let prev_v4_snapshot = previous_rules_guard.clone(); - let prev_v6_snapshot = previous_rules_v6_guard.clone(); - let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); - let 
added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); - let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); - let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); - - // Apply to nftables - let mut fw = nft_fw.lock().map_err(|e| format!("Lock error: {}", e))?; - - if ipv4_changed { - for (net, prefix) in &removed_v4 { - if let Err(e) = fw.unban_ip(*net, *prefix) { - log::error!("nftables: IPv4 unban failed for {}/{}: {}", net, prefix, e); - } - } - for (net, prefix) in &added_v4 { - if let Err(e) = fw.ban_ip(*net, *prefix) { - log::error!("nftables: IPv4 ban failed for {}/{}: {}", net, prefix, e); - } - } - log::debug!("nftables: Applied {} IPv4 rule changes (+{}, -{})", - added_v4.len() + removed_v4.len(), added_v4.len(), removed_v4.len()); - } - - if ipv6_changed { - for (net, prefix) in &removed_v6 { - if let Err(e) = fw.unban_ipv6(*net, *prefix) { - log::error!("nftables: IPv6 unban failed for {}/{}: {}", net, prefix, e); - } - } - for (net, prefix) in &added_v6 { - if let Err(e) = fw.ban_ipv6(*net, *prefix) { - log::error!("nftables: IPv6 ban failed for {}/{}: {}", net, prefix, e); - } - } - log::debug!("nftables: Applied {} IPv6 rule changes (+{}, -{})", - added_v6.len() + removed_v6.len(), added_v6.len(), removed_v6.len()); - } - - // Update previous snapshots - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } - - Ok(()) -} - -/// Initialize access rules for nftables backend -pub fn init_access_rules_nftables( - nft_fw: &Arc>, -) -> Result<(), Box> { - let previous_rules: PreviousRules = Arc::new(Mutex::new(HashSet::new())); - let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(HashSet::new())); - apply_rules_nftables(nft_fw, &previous_rules, &previous_rules_v6) -} - -/// Initialize access rules for iptables backend -pub fn init_access_rules_iptables( - ipt_fw: &Arc>, -) -> Result<(), Box> { - let previous_rules: PreviousRules = Arc::new(Mutex::new(HashSet::new())); - let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(HashSet::new())); - apply_rules_iptables(ipt_fw, &previous_rules, &previous_rules_v6) -} - -/// Apply access rules using iptables firewall backend -pub fn apply_rules_iptables( - ipt_fw: &Arc>, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, -) -> Result<(), Box> { - if let Ok(guard) = global_config().read() { - if let Some(cfg_arc) = guard.as_ref() { - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; - apply_rules_to_iptables(ipt_fw, &resp, previous_rules, previous_rules_v6)?; - } - } - Ok(()) -} - -/// Apply access rules to iptables firewall -fn apply_rules_to_iptables( - ipt_fw: &Arc>, - resp: &config::ConfigApiResponse, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, -) -> Result<(), Box> { - fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - if let Some(slash) = s.find('/') { - let (ip_str, prefix_str) = s.split_at(slash); - let prefix_str = &prefix_str[1..]; - if let (Ok(ip), Ok(prefix)) = (Ipv4Addr::from_str(ip_str), prefix_str.parse::()) { - return Some((ip, prefix)); - } - } else if let Ok(ip) = Ipv4Addr::from_str(s) { - return Some((ip, 32)); - } - None - } - - fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { - let s = entry.trim(); - if 
s.is_empty() { - return None; - } - if let Some(slash) = s.find('/') { - let (ip_str, prefix_str) = s.split_at(slash); - let prefix_str = &prefix_str[1..]; - if let (Ok(ip), Ok(prefix)) = (Ipv6Addr::from_str(ip_str), prefix_str.parse::()) { - return Some((ip, prefix)); - } - } else if let Ok(ip) = Ipv6Addr::from_str(s) { - return Some((ip, 128)); - } - None - } - - // Build current rules sets from config - let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); - let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); - - // Process block rules - for entry in &resp.config.access_rules.block.ips { - if let Some(parsed) = parse_ipv4_ip_or_cidr(entry) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(entry) { - current_rules_v6.insert(parsed); - } - } - - // Process country-based block rules - for country_map in &resp.config.access_rules.block.country { - for (_country_code, ip_list) in country_map.iter() { - for ip_entry in ip_list { - if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_entry) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_entry) { - current_rules_v6.insert(parsed); - } - } - } - } - - // Process ASN-based block rules - for asn_map in &resp.config.access_rules.block.asn { - for (_asn, ip_list) in asn_map.iter() { - for ip_entry in ip_list { - if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_entry) { - current_rules.insert(parsed); - } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_entry) { - current_rules_v6.insert(parsed); - } - } - } - } - - // Get previous rules - let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; - let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; - - // Find IPs to add (in current but not in previous) - let to_add: Vec<_> = current_rules.difference(&*previous_rules_guard).cloned().collect(); - let to_add_v6: Vec<_> = current_rules_v6.difference(&*previous_rules_v6_guard).cloned().collect(); - - // Find IPs to remove (in previous but not in current) - let to_remove: Vec<_> = previous_rules_guard.difference(¤t_rules).cloned().collect(); - let to_remove_v6: Vec<_> = previous_rules_v6_guard.difference(¤t_rules_v6).cloned().collect(); - - let ipv4_changed = !to_add.is_empty() || !to_remove.is_empty(); - let ipv6_changed = !to_add_v6.is_empty() || !to_remove_v6.is_empty(); - - // Apply changes - if ipv4_changed || ipv6_changed { - let mut fw = ipt_fw.lock().map_err(|e| format!("Lock error: {}", e))?; - - // Add new rules - for (ip, prefix) in &to_add { - if let Err(e) = fw.ban_ip(*ip, *prefix) { - log::error!("iptables: IPv4 ban failed for {}/{}: {}", ip, prefix, e); - } - } - for (ip, prefix) in &to_add_v6 { - if let Err(e) = fw.ban_ipv6(*ip, *prefix) { - log::error!("iptables: IPv6 ban failed for {}/{}: {}", ip, prefix, e); - } - } - - // Remove old rules - for (ip, prefix) in &to_remove { - if let Err(e) = fw.unban_ip(*ip, *prefix) { - log::error!("iptables: IPv4 unban failed for {}/{}: {}", ip, prefix, e); - } - } - for (ip, prefix) in &to_remove_v6 { - if let Err(e) = fw.unban_ipv6(*ip, *prefix) { - log::error!("iptables: IPv6 unban failed for {}/{}: {}", ip, prefix, e); - } - } - - if !to_add.is_empty() || !to_remove.is_empty() { - log::info!("iptables: IPv4 rules updated (+{} -{} total={})", to_add.len(), to_remove.len(), current_rules.len()); - } - if !to_add_v6.is_empty() || !to_remove_v6.is_empty() { - log::info!("iptables: IPv6 rules updated (+{} -{} 
total={})", to_add_v6.len(), to_remove_v6.len(), current_rules_v6.len()); - } - } - - // Update previous rules - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } - - Ok(()) -} - -fn apply_rules( - skels: &Vec>>, - resp: &config::ConfigApiResponse, - previous_rules: &PreviousRules, - previous_rules_v6: &PreviousRulesV6, -) -> Result<(), Box> { - fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - if s.contains(':') { - // IPv6 not supported by IPv4 map - return None; - } - if !s.contains('/') { - return Ipv4Addr::from_str(s).ok().map(|ip| (ip, 32)); - } - let mut parts = s.split('/'); - let ip_str = parts.next()?.trim(); - let prefix_str = parts.next()?.trim(); - if parts.next().is_some() { - // malformed - return None; - } - let ip = Ipv4Addr::from_str(ip_str).ok()?; - let prefix: u32 = prefix_str.parse::().ok()? as u32; - if prefix > 32 { - return None; - } - let ip_u32 = u32::from(ip); - let mask = if prefix == 0 { - 0 - } else { - u32::MAX.checked_shl(32 - prefix).unwrap_or(0) - }; - let net = Ipv4Addr::from(ip_u32 & mask); - Some((net, prefix)) - } - - // Helper: parse IPv6 or IPv6/CIDR into (network, prefix) - fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - if !s.contains(':') { - // IPv4 not supported by IPv6 map - return None; - } - if !s.contains('/') { - return Ipv6Addr::from_str(s).ok().map(|ip| (ip, 128)); - } - let mut parts = s.split('/'); - let ip_str = parts.next()?.trim(); - let prefix_str = parts.next()?.trim(); - if parts.next().is_some() { - // malformed - return None; - } - let ip = Ipv6Addr::from_str(ip_str).ok()?; - let prefix: u32 = prefix_str.parse::().ok()? 
as u32; - if prefix > 128 { - return None; - } - Some((ip, prefix)) - } - - let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); - let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); - - let rule = &resp.config.access_rules; - - // Parse block.ips - for ip_str in &rule.block.ips { - if ip_str.contains(':') { - // IPv6 address - if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { - current_rules_v6.insert((net, prefix)); - } else { - log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); - } - } else { - // IPv4 address - if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { - current_rules.insert((net, prefix)); - } else { - log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); - } - } - } - - // Parse block.country values - for country_map in &rule.block.country { - for (_cc, list) in country_map.iter() { - for ip_str in list { - if ip_str.contains(':') { - // IPv6 address - if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { - current_rules_v6.insert((net, prefix)); - } else { - log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); - } - } else { - // IPv4 address - if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { - current_rules.insert((net, prefix)); - } else { - log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); - } - } - } - } - } - - // Parse block.asn values - for asn_map in &rule.block.asn { - for (_asn, list) in asn_map.iter() { - for ip_str in list { - if ip_str.contains(':') { - // IPv6 address - if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { - current_rules_v6.insert((net, prefix)); - } else { - log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); - } - } else { - // IPv4 address - if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { - current_rules.insert((net, prefix)); - } else { - log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); - } - } - } - } - } - - // Compare with previous rules to detect changes - let mut previous_rules_guard = previous_rules.lock().unwrap(); - let mut previous_rules_v6_guard = previous_rules_v6.lock().unwrap(); - - // Check if rules have changed - let ipv4_changed = *previous_rules_guard != current_rules; - let ipv6_changed = *previous_rules_v6_guard != current_rules_v6; - - // If neither family changed, skip quietly with a single log entry - if !ipv4_changed && !ipv6_changed { - log::debug!("No IPv4 or IPv6 access rule changes detected, skipping BPF map updates"); - return Ok(()); - } - - // Compute diffs once against snapshots - let prev_v4_snapshot = previous_rules_guard.clone(); - let prev_v6_snapshot = previous_rules_v6_guard.clone(); - let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); - let added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); - let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); - let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); - - // Apply to all BPF skeletons - for s in skels.iter() { - let mut fw = SYNAPSEFirewall::new(s); - if ipv4_changed { - for (net, prefix) in &removed_v4 { - if let Err(e) = fw.unban_ip(*net, *prefix) { - log::error!("IPv4 unban failed for {}/{}: {}", net, prefix, e); - } - } - for (net, prefix) in &added_v4 { - if let Err(e) = fw.ban_ip(*net, *prefix) { - log::error!("IPv4 ban failed for {}/{}: {}", net, prefix, e); - } - } - } - if ipv6_changed { - for (net, prefix) in &removed_v6 { - if let Err(e) = 
fw.unban_ipv6(*net, *prefix) { - log::error!("IPv6 unban failed for {}/{}: {}", net, prefix, e); - } - } - for (net, prefix) in &added_v6 { - if let Err(e) = fw.ban_ipv6(*net, *prefix) { - log::error!("IPv6 ban failed for {}/{}: {}", net, prefix, e); - } - } - } - } - - // Update previous snapshots once after applying to all skels - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } - - Ok(()) -} - -/// Check if an IP address is allowed by access rules -/// Returns true if the IP is explicitly allowed, false otherwise -pub fn is_ip_allowed_by_access_rules(ip: IpAddr) -> bool { - if let Ok(guard) = global_config().read() { - if let Some(cfg) = guard.as_ref() { - let allow_rules = &cfg.access_rules.allow; - - // Check direct IP matches - for ip_str in &allow_rules.ips { - if let Ok(allowed_ip) = ip_str.parse::<IpAddr>() { - if ip == allowed_ip { - return true; - } - } - - // Check CIDR ranges - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - - // Check country-based allow rules - for country_map in &allow_rules.country { - for (_country_code, ip_list) in country_map.iter() { - for ip_str in ip_list { - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - } - } - - // Check ASN-based allow rules - for asn_map in &allow_rules.asn { - for (_asn, ip_list) in asn_map.iter() { - for ip_str in ip_list { - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - } - } - } - } - false -} - -/// Check if an IP address is blocked by access rules -/// Returns true if the IP should be blocked, false otherwise -pub fn is_ip_blocked_by_access_rules(ip: IpAddr) -> bool { - if let Ok(guard) = global_config().read() { - if let Some(cfg) = guard.as_ref() { - let block_rules = &cfg.access_rules.block; - - // Check direct IP matches - for ip_str in &block_rules.ips { - if let Ok(blocked_ip) = ip_str.parse::<IpAddr>() { - if ip == blocked_ip { - return true; - } - } - - // Check CIDR ranges - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - - // Check country-based block rules - for country_map in &block_rules.country { - for (_country_code, ip_list) in country_map.iter() { - for ip_str in ip_list { - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - } - } - - // Check ASN-based block rules - for asn_map in &block_rules.asn { - for (_asn, ip_list) in asn_map.iter() { - for ip_str in ip_list { - if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { - if is_ip_in_cidr(ip, network, prefix_len) { - return true; - } - } - } - } - } - } - } - false -} +use std::collections::HashSet; +use std::net::{Ipv4Addr, Ipv6Addr, IpAddr}; +use std::str::FromStr; +use std::sync::{Arc, Mutex, OnceLock}; + +use crate::bpf; +use crate::worker::config; +use crate::worker::config::global_config; +use crate::waf::wirefilter::update_http_filter_from_config_value; +use crate::firewall::{Firewall, SYNAPSEFirewall, NftablesFirewall, IptablesFirewall}; +use crate::utils::http_utils::{parse_ip_or_cidr, is_ip_in_cidr}; + +// Store previous rules state for comparison +type PreviousRules = Arc<Mutex<HashSet<(Ipv4Addr, u32)>>>; +type PreviousRulesV6 = Arc<Mutex<HashSet<(Ipv6Addr, u32)>>>; + +// Global previous
rules state for the access rules worker +static PREVIOUS_RULES: OnceLock<PreviousRules> = OnceLock::new(); +static PREVIOUS_RULES_V6: OnceLock<PreviousRulesV6> = OnceLock::new(); + +fn get_previous_rules() -> &'static PreviousRules { + PREVIOUS_RULES.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) +} + +fn get_previous_rules_v6() -> &'static PreviousRulesV6 { + PREVIOUS_RULES_V6.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) +} + +/// Apply access rules once using the current global config snapshot (initial setup) +pub fn init_access_rules_from_global( + skels: &Vec>>, +) -> Result<(), Box<dyn std::error::Error>> { + if skels.is_empty() { + return Ok(()); + } + if let Ok(guard) = global_config().read() { + if let Some(cfg_arc) = guard.as_ref() { + let previous_rules: PreviousRules = Arc::new(Mutex::new(std::collections::HashSet::new())); + let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(std::collections::HashSet::new())); + // Use Arc clone instead of full Config clone for efficiency + let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + apply_rules(skels, &resp, &previous_rules, &previous_rules_v6)?; + } + } + Ok(()) +} + +/// Apply access rules and WAF rules from global config snapshot +/// This is called periodically by the ConfigWorker after it fetches new config +/// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates +pub fn apply_rules_from_global( + skels: &Vec>>, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, + skip_waf_update: bool, +) -> Result<(), Box<dyn std::error::Error>> { + // Read from global config and apply if available + if let Ok(guard) = global_config().read() { + if let Some(cfg_arc) = guard.as_ref() { + // Update WAF wirefilter when config changes (skip in agent mode) + if !skip_waf_update { + if let Err(e) = update_http_filter_from_config_value(cfg_arc) { + log::error!("failed to update HTTP filter from config: {e}"); + } + } + if skels.is_empty() { + return Ok(()); + } + // Use Arc clone instead of full Config clone for efficiency + apply_rules( + skels, + &config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }, + previous_rules, + previous_rules_v6, + )?; + return Ok(()); + } + } + Ok(()) +} + +/// Apply access rules and WAF rules from global config using global state +/// This is called by the ConfigWorker after it fetches new config +/// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates +pub fn apply_rules_from_global_with_state( + skels: &Vec>>, + skip_waf_update: bool, +) -> Result<(), Box<dyn std::error::Error>> { + apply_rules_from_global(skels, get_previous_rules(), get_previous_rules_v6(), skip_waf_update) +} + +/// Apply access rules using nftables firewall backend +/// This is used when BPF/XDP is not available +pub fn apply_rules_nftables( + nft_fw: &Arc<Mutex<NftablesFirewall>>, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, +) -> Result<(), Box<dyn std::error::Error>> { + if let Ok(guard) = global_config().read() { + if let Some(cfg_arc) = guard.as_ref() { + let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + apply_rules_to_nftables(nft_fw, &resp, previous_rules, previous_rules_v6)?; + } + } + Ok(()) +} + +/// Apply access rules to nftables firewall +fn apply_rules_to_nftables( + nft_fw: &Arc<Mutex<NftablesFirewall>>, + resp: &config::ConfigApiResponse, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, +) -> Result<(), Box<dyn std::error::Error>> { + fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if let Some(slash) =
s.find('/') { + let (ip_str, prefix_str) = s.split_at(slash); + let prefix_str = &prefix_str[1..]; + if let (Ok(ip), Ok(prefix)) = (Ipv4Addr::from_str(ip_str), prefix_str.parse::()) { + return Some((ip, prefix)); + } + } else if let Ok(ip) = Ipv4Addr::from_str(s) { + return Some((ip, 32)); + } + None + } + + fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if let Some(slash) = s.find('/') { + let (ip_str, prefix_str) = s.split_at(slash); + let prefix_str = &prefix_str[1..]; + if let (Ok(ip), Ok(prefix)) = (Ipv6Addr::from_str(ip_str), prefix_str.parse::()) { + return Some((ip, prefix)); + } + } else if let Ok(ip) = Ipv6Addr::from_str(s) { + return Some((ip, 128)); + } + None + } + + // Build current rules sets from config + let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); + let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); + + // Process block rules + for entry in &resp.config.access_rules.block.ips { + if let Some(parsed) = parse_ipv4_ip_or_cidr(entry) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(entry) { + current_rules_v6.insert(parsed); + } + } + + // Process country-based block rules + for country_map in &resp.config.access_rules.block.country { + for (_country_code, ip_list) in country_map.iter() { + for ip_str in ip_list { + if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_str) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_str) { + current_rules_v6.insert(parsed); + } + } + } + } + + // Process ASN-based block rules + for asn_map in &resp.config.access_rules.block.asn { + for (_asn, ip_list) in asn_map.iter() { + for ip_str in ip_list { + if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_str) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_str) { + current_rules_v6.insert(parsed); + } + } + } + } + + // Lock previous rules + let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; + + // Check for changes + let ipv4_changed = current_rules != *previous_rules_guard; + let ipv6_changed = current_rules_v6 != *previous_rules_v6_guard; + + if !ipv4_changed && !ipv6_changed { + return Ok(()); + } + + // Compute diffs + let prev_v4_snapshot = previous_rules_guard.clone(); + let prev_v6_snapshot = previous_rules_v6_guard.clone(); + let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); + let added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); + let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); + let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); + + // Apply to nftables + let mut fw = nft_fw.lock().map_err(|e| format!("Lock error: {}", e))?; + + if ipv4_changed { + for (net, prefix) in &removed_v4 { + if let Err(e) = fw.unban_ip(*net, *prefix) { + log::error!("nftables: IPv4 unban failed for {}/{}: {}", net, prefix, e); + } + } + for (net, prefix) in &added_v4 { + if let Err(e) = fw.ban_ip(*net, *prefix) { + log::error!("nftables: IPv4 ban failed for {}/{}: {}", net, prefix, e); + } + } + log::debug!("nftables: Applied {} IPv4 rule changes (+{}, -{})", + added_v4.len() + removed_v4.len(), added_v4.len(), removed_v4.len()); + 
} + + if ipv6_changed { + for (net, prefix) in &removed_v6 { + if let Err(e) = fw.unban_ipv6(*net, *prefix) { + log::error!("nftables: IPv6 unban failed for {}/{}: {}", net, prefix, e); + } + } + for (net, prefix) in &added_v6 { + if let Err(e) = fw.ban_ipv6(*net, *prefix) { + log::error!("nftables: IPv6 ban failed for {}/{}: {}", net, prefix, e); + } + } + log::debug!("nftables: Applied {} IPv6 rule changes (+{}, -{})", + added_v6.len() + removed_v6.len(), added_v6.len(), removed_v6.len()); + } + + // Update previous snapshots + if ipv4_changed { *previous_rules_guard = current_rules; } + if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + + Ok(()) +} + +/// Initialize access rules for nftables backend +pub fn init_access_rules_nftables( + nft_fw: &Arc<Mutex<NftablesFirewall>>, +) -> Result<(), Box<dyn std::error::Error>> { + let previous_rules: PreviousRules = Arc::new(Mutex::new(HashSet::new())); + let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(HashSet::new())); + apply_rules_nftables(nft_fw, &previous_rules, &previous_rules_v6) +} + +/// Initialize access rules for iptables backend +pub fn init_access_rules_iptables( + ipt_fw: &Arc<Mutex<IptablesFirewall>>, +) -> Result<(), Box<dyn std::error::Error>> { + let previous_rules: PreviousRules = Arc::new(Mutex::new(HashSet::new())); + let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(HashSet::new())); + apply_rules_iptables(ipt_fw, &previous_rules, &previous_rules_v6) +} + +/// Apply access rules using iptables firewall backend +pub fn apply_rules_iptables( + ipt_fw: &Arc<Mutex<IptablesFirewall>>, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, +) -> Result<(), Box<dyn std::error::Error>> { + if let Ok(guard) = global_config().read() { + if let Some(cfg_arc) = guard.as_ref() { + let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + apply_rules_to_iptables(ipt_fw, &resp, previous_rules, previous_rules_v6)?; + } + } + Ok(()) +} + +/// Apply access rules to iptables firewall +fn apply_rules_to_iptables( + ipt_fw: &Arc<Mutex<IptablesFirewall>>, + resp: &config::ConfigApiResponse, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, +) -> Result<(), Box<dyn std::error::Error>> { + fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if let Some(slash) = s.find('/') { + let (ip_str, prefix_str) = s.split_at(slash); + let prefix_str = &prefix_str[1..]; + if let (Ok(ip), Ok(prefix)) = (Ipv4Addr::from_str(ip_str), prefix_str.parse::<u32>()) { + return Some((ip, prefix)); + } + } else if let Ok(ip) = Ipv4Addr::from_str(s) { + return Some((ip, 32)); + } + None + } + + fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if let Some(slash) = s.find('/') { + let (ip_str, prefix_str) = s.split_at(slash); + let prefix_str = &prefix_str[1..]; + if let (Ok(ip), Ok(prefix)) = (Ipv6Addr::from_str(ip_str), prefix_str.parse::<u32>()) { + return Some((ip, prefix)); + } + } else if let Ok(ip) = Ipv6Addr::from_str(s) { + return Some((ip, 128)); + } + None + } + + // Build current rules sets from config + let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); + let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); + + // Process block rules + for entry in &resp.config.access_rules.block.ips { + if let Some(parsed) = parse_ipv4_ip_or_cidr(entry) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(entry) { + current_rules_v6.insert(parsed); + } + } + + // Process country-based block rules + for country_map in
&resp.config.access_rules.block.country { + for (_country_code, ip_list) in country_map.iter() { + for ip_entry in ip_list { + if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_entry) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_entry) { + current_rules_v6.insert(parsed); + } + } + } + } + + // Process ASN-based block rules + for asn_map in &resp.config.access_rules.block.asn { + for (_asn, ip_list) in asn_map.iter() { + for ip_entry in ip_list { + if let Some(parsed) = parse_ipv4_ip_or_cidr(ip_entry) { + current_rules.insert(parsed); + } else if let Some(parsed) = parse_ipv6_ip_or_cidr(ip_entry) { + current_rules_v6.insert(parsed); + } + } + } + } + + // Get previous rules + let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; + + // Find IPs to add (in current but not in previous) + let to_add: Vec<_> = current_rules.difference(&*previous_rules_guard).cloned().collect(); + let to_add_v6: Vec<_> = current_rules_v6.difference(&*previous_rules_v6_guard).cloned().collect(); + + // Find IPs to remove (in previous but not in current) + let to_remove: Vec<_> = previous_rules_guard.difference(¤t_rules).cloned().collect(); + let to_remove_v6: Vec<_> = previous_rules_v6_guard.difference(¤t_rules_v6).cloned().collect(); + + let ipv4_changed = !to_add.is_empty() || !to_remove.is_empty(); + let ipv6_changed = !to_add_v6.is_empty() || !to_remove_v6.is_empty(); + + // Apply changes + if ipv4_changed || ipv6_changed { + let mut fw = ipt_fw.lock().map_err(|e| format!("Lock error: {}", e))?; + + // Add new rules + for (ip, prefix) in &to_add { + if let Err(e) = fw.ban_ip(*ip, *prefix) { + log::error!("iptables: IPv4 ban failed for {}/{}: {}", ip, prefix, e); + } + } + for (ip, prefix) in &to_add_v6 { + if let Err(e) = fw.ban_ipv6(*ip, *prefix) { + log::error!("iptables: IPv6 ban failed for {}/{}: {}", ip, prefix, e); + } + } + + // Remove old rules + for (ip, prefix) in &to_remove { + if let Err(e) = fw.unban_ip(*ip, *prefix) { + log::error!("iptables: IPv4 unban failed for {}/{}: {}", ip, prefix, e); + } + } + for (ip, prefix) in &to_remove_v6 { + if let Err(e) = fw.unban_ipv6(*ip, *prefix) { + log::error!("iptables: IPv6 unban failed for {}/{}: {}", ip, prefix, e); + } + } + + if !to_add.is_empty() || !to_remove.is_empty() { + log::info!("iptables: IPv4 rules updated (+{} -{} total={})", to_add.len(), to_remove.len(), current_rules.len()); + } + if !to_add_v6.is_empty() || !to_remove_v6.is_empty() { + log::info!("iptables: IPv6 rules updated (+{} -{} total={})", to_add_v6.len(), to_remove_v6.len(), current_rules_v6.len()); + } + } + + // Update previous rules + if ipv4_changed { *previous_rules_guard = current_rules; } + if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + + Ok(()) +} + +fn apply_rules( + skels: &Vec>>, + resp: &config::ConfigApiResponse, + previous_rules: &PreviousRules, + previous_rules_v6: &PreviousRulesV6, +) -> Result<(), Box> { + fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if s.contains(':') { + // IPv6 not supported by IPv4 map + return None; + } + if !s.contains('/') { + return Ipv4Addr::from_str(s).ok().map(|ip| (ip, 32)); + } + let mut parts = s.split('/'); + let ip_str = parts.next()?.trim(); + let prefix_str = parts.next()?.trim(); + if parts.next().is_some() { + // malformed + return 
None; + } + let ip = Ipv4Addr::from_str(ip_str).ok()?; + let prefix: u32 = prefix_str.parse::().ok()? as u32; + if prefix > 32 { + return None; + } + let ip_u32 = u32::from(ip); + let mask = if prefix == 0 { + 0 + } else { + u32::MAX.checked_shl(32 - prefix).unwrap_or(0) + }; + let net = Ipv4Addr::from(ip_u32 & mask); + Some((net, prefix)) + } + + // Helper: parse IPv6 or IPv6/CIDR into (network, prefix) + fn parse_ipv6_ip_or_cidr(entry: &str) -> Option<(Ipv6Addr, u32)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + if !s.contains(':') { + // IPv4 not supported by IPv6 map + return None; + } + if !s.contains('/') { + return Ipv6Addr::from_str(s).ok().map(|ip| (ip, 128)); + } + let mut parts = s.split('/'); + let ip_str = parts.next()?.trim(); + let prefix_str = parts.next()?.trim(); + if parts.next().is_some() { + // malformed + return None; + } + let ip = Ipv6Addr::from_str(ip_str).ok()?; + let prefix: u32 = prefix_str.parse::().ok()? as u32; + if prefix > 128 { + return None; + } + Some((ip, prefix)) + } + + let mut current_rules: HashSet<(Ipv4Addr, u32)> = HashSet::new(); + let mut current_rules_v6: HashSet<(Ipv6Addr, u32)> = HashSet::new(); + + let rule = &resp.config.access_rules; + + // Parse block.ips + for ip_str in &rule.block.ips { + if ip_str.contains(':') { + // IPv6 address + if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { + current_rules_v6.insert((net, prefix)); + } else { + log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); + } + } else { + // IPv4 address + if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { + current_rules.insert((net, prefix)); + } else { + log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); + } + } + } + + // Parse block.country values + for country_map in &rule.block.country { + for (_cc, list) in country_map.iter() { + for ip_str in list { + if ip_str.contains(':') { + // IPv6 address + if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { + current_rules_v6.insert((net, prefix)); + } else { + log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); + } + } else { + // IPv4 address + if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { + current_rules.insert((net, prefix)); + } else { + log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); + } + } + } + } + } + + // Parse block.asn values + for asn_map in &rule.block.asn { + for (_asn, list) in asn_map.iter() { + for ip_str in list { + if ip_str.contains(':') { + // IPv6 address + if let Some((net, prefix)) = parse_ipv6_ip_or_cidr(ip_str) { + current_rules_v6.insert((net, prefix)); + } else { + log::warn!("invalid IPv6 ip/cidr ignored: {}", ip_str); + } + } else { + // IPv4 address + if let Some((net, prefix)) = parse_ipv4_ip_or_cidr(ip_str) { + current_rules.insert((net, prefix)); + } else { + log::warn!("invalid IPv4 ip/cidr ignored: {}", ip_str); + } + } + } + } + } + + // Compare with previous rules to detect changes + let mut previous_rules_guard = previous_rules.lock().unwrap(); + let mut previous_rules_v6_guard = previous_rules_v6.lock().unwrap(); + + // Check if rules have changed + let ipv4_changed = *previous_rules_guard != current_rules; + let ipv6_changed = *previous_rules_v6_guard != current_rules_v6; + + // If neither family changed, skip quietly with a single log entry + if !ipv4_changed && !ipv6_changed { + log::debug!("No IPv4 or IPv6 access rule changes detected, skipping BPF map updates"); + return Ok(()); + } + + // Compute diffs once against snapshots + let prev_v4_snapshot = previous_rules_guard.clone(); + let 
prev_v6_snapshot = previous_rules_v6_guard.clone(); + let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); + let added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); + let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); + let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); + + // Apply to all BPF skeletons + for s in skels.iter() { + let mut fw = SYNAPSEFirewall::new(s); + if ipv4_changed { + for (net, prefix) in &removed_v4 { + if let Err(e) = fw.unban_ip(*net, *prefix) { + log::error!("IPv4 unban failed for {}/{}: {}", net, prefix, e); + } + } + for (net, prefix) in &added_v4 { + if let Err(e) = fw.ban_ip(*net, *prefix) { + log::error!("IPv4 ban failed for {}/{}: {}", net, prefix, e); + } + } + } + if ipv6_changed { + for (net, prefix) in &removed_v6 { + if let Err(e) = fw.unban_ipv6(*net, *prefix) { + log::error!("IPv6 unban failed for {}/{}: {}", net, prefix, e); + } + } + for (net, prefix) in &added_v6 { + if let Err(e) = fw.ban_ipv6(*net, *prefix) { + log::error!("IPv6 ban failed for {}/{}: {}", net, prefix, e); + } + } + } + } + + // Update previous snapshots once after applying to all skels + if ipv4_changed { *previous_rules_guard = current_rules; } + if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + + Ok(()) +} + +/// Check if an IP address is allowed by access rules +/// Returns true if the IP is explicitly allowed, false otherwise +pub fn is_ip_allowed_by_access_rules(ip: IpAddr) -> bool { + if let Ok(guard) = global_config().read() { + if let Some(cfg) = guard.as_ref() { + let allow_rules = &cfg.access_rules.allow; + + // Check direct IP matches + for ip_str in &allow_rules.ips { + if let Ok(allowed_ip) = ip_str.parse::() { + if ip == allowed_ip { + return true; + } + } + + // Check CIDR ranges + if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + + // Check country-based allow rules + for country_map in &allow_rules.country { + for (_country_code, ip_list) in country_map.iter() { + for ip_str in ip_list { + if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + } + } + + // Check ASN-based allow rules + for asn_map in &allow_rules.asn { + for (_asn, ip_list) in asn_map.iter() { + for ip_str in ip_list { + if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + } + } + } + } + false +} + +/// Check if an IP address is blocked by access rules +/// Returns true if the IP should be blocked, false otherwise +pub fn is_ip_blocked_by_access_rules(ip: IpAddr) -> bool { + if let Ok(guard) = global_config().read() { + if let Some(cfg) = guard.as_ref() { + let block_rules = &cfg.access_rules.block; + + // Check direct IP matches + for ip_str in &block_rules.ips { + if let Ok(blocked_ip) = ip_str.parse::() { + if ip == blocked_ip { + return true; + } + } + + // Check CIDR ranges + if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + + // Check country-based block rules + for country_map in &block_rules.country { + for (_country_code, ip_list) in country_map.iter() { + for ip_str in ip_list { + if let Some((network, prefix_len)) = 
parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + } + } + + // Check ASN-based block rules + for asn_map in &block_rules.asn { + for (_asn, ip_list) in asn_map.iter() { + for ip_str in ip_list { + if let Some((network, prefix_len)) = parse_ip_or_cidr(ip_str) { + if is_ip_in_cidr(ip, network, prefix_len) { + return true; + } + } + } + } + } + } + } + false +} diff --git a/src/acme/config.rs b/src/acme/config.rs index fa50b7d..29757d6 100644 --- a/src/acme/config.rs +++ b/src/acme/config.rs @@ -1,263 +1,263 @@ -use std::path::PathBuf; -use serde::{Deserialize, Serialize}; -use anyhow::{Result, Context}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Config { - pub https_path: PathBuf, - pub cert_path: PathBuf, - pub key_path: PathBuf, - pub static_path: PathBuf, - pub opts: ConfigOpts, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConfigOpts { - pub ip: String, - pub port: u16, - pub domain: String, - pub email: Option, - pub https_dns: bool, - pub development: bool, - pub dns_lookup_max_attempts: Option, - pub dns_lookup_delay_seconds: Option, - pub storage_type: Option, - pub redis_url: Option, - pub lock_ttl_seconds: Option, - pub redis_ssl: Option, - pub challenge_max_ttl_seconds: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AppConfig { - pub server: ServerConfig, - pub storage: StorageConfig, - pub acme: AcmeConfig, - pub domains: crate::acme::domain_reader::DomainSourceConfig, - #[serde(default)] - pub logging: LoggingConfig, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerConfig { - pub ip: String, - pub port: u16, - /// Run as daemon (background process) - #[serde(default)] - pub daemon: bool, - /// PID file path (for daemon mode) - pub pid_file: Option, - /// Working directory for daemon - pub working_directory: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LoggingConfig { - /// Logging output: "stdout", "syslog", or "journald" - #[serde(default = "default_log_output")] - pub output: String, - /// Log level: trace, debug, info, warn, error - #[serde(default = "default_log_level")] - pub level: String, - /// Syslog facility (for syslog output) - #[serde(default = "default_syslog_facility")] - pub syslog_facility: String, - /// Syslog identifier/tag (for syslog output) - #[serde(default = "default_syslog_identifier")] - pub syslog_identifier: String, -} - -impl Default for LoggingConfig { - fn default() -> Self { - Self { - output: default_log_output(), - level: default_log_level(), - syslog_facility: default_syslog_facility(), - syslog_identifier: default_syslog_identifier(), - } - } -} - -fn default_log_output() -> String { - "stdout".to_string() -} - -fn default_log_level() -> String { - "info".to_string() -} - -fn default_syslog_facility() -> String { - "daemon".to_string() -} - -fn default_syslog_identifier() -> String { - "ssl-storage".to_string() -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StorageConfig { - #[serde(rename = "type")] - pub storage_type: String, - pub https_path: String, - pub redis_url: Option, - #[serde(default = "default_lock_ttl_seconds")] - pub lock_ttl_seconds: u64, - /// Redis SSL/TLS configuration - #[serde(default)] - pub redis_ssl: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RedisSslConfig { - /// Path to CA certificate file (PEM format) - pub ca_cert_path: Option, - /// Path to client certificate file (PEM format, optional) - pub 
client_cert_path: Option, - /// Path to client private key file (PEM format, optional) - pub client_key_path: Option, - /// Skip certificate verification (for testing with self-signed certs) - #[serde(default)] - pub insecure: bool, -} - -fn default_lock_ttl_seconds() -> u64 { - 900 // 15 minutes default -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AcmeConfig { - pub email: String, - pub development: bool, - #[serde(default = "default_dns_lookup_config")] - pub dns_lookup: DnsLookupConfig, - #[serde(default = "default_retry_config")] - pub retry: RetryConfig, - /// Maximum TTL in seconds for ACME challenges (default: 3600 = 1 hour) - /// Challenges older than this will be considered expired and regenerated - #[serde(default = "default_challenge_max_ttl_seconds")] - pub challenge_max_ttl_seconds: u64, -} - -fn default_challenge_max_ttl_seconds() -> u64 { - 3600 // 1 hour - ACME challenges are typically valid for a limited time -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DnsLookupConfig { - #[serde(default = "default_max_attempts")] - pub max_attempts: u32, - #[serde(default = "default_delay_seconds")] - pub delay_seconds: u64, -} - -fn default_dns_lookup_config() -> DnsLookupConfig { - DnsLookupConfig { - max_attempts: default_max_attempts(), - delay_seconds: default_delay_seconds(), - } -} - -fn default_max_attempts() -> u32 { - 100 -} - -fn default_delay_seconds() -> u64 { - 10 -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RetryConfig { - /// Enable periodic rechecking of failed certificates - #[serde(default = "default_enable_periodic_check")] - pub enable_periodic_check: bool, - /// Interval in seconds between periodic checks (default: 3600 = 1 hour) - #[serde(default = "default_check_interval_seconds")] - pub check_interval_seconds: u64, - /// Minimum delay in seconds before retrying a failed certificate (default: 300 = 5 minutes) - #[serde(default = "default_min_retry_delay_seconds")] - pub min_retry_delay_seconds: u64, - /// Maximum delay in seconds before retrying a failed certificate (default: 86400 = 24 hours) - #[serde(default = "default_max_retry_delay_seconds")] - pub max_retry_delay_seconds: u64, - /// Maximum number of retries before giving up (0 = unlimited, default: 0) - #[serde(default = "default_max_retries")] - pub max_retries: u32, -} - -fn default_retry_config() -> RetryConfig { - RetryConfig { - enable_periodic_check: default_enable_periodic_check(), - check_interval_seconds: default_check_interval_seconds(), - min_retry_delay_seconds: default_min_retry_delay_seconds(), - max_retry_delay_seconds: default_max_retry_delay_seconds(), - max_retries: default_max_retries(), - } -} - -fn default_enable_periodic_check() -> bool { - true -} - -fn default_check_interval_seconds() -> u64 { - 3600 // 1 hour -} - -fn default_min_retry_delay_seconds() -> u64 { - 300 // 5 minutes -} - -fn default_max_retry_delay_seconds() -> u64 { - 86400 // 24 hours -} - -fn default_max_retries() -> u32 { - 0 // unlimited -} - -impl AppConfig { - /// Load configuration from YAML file - pub fn from_file(path: impl AsRef) -> Result { - use std::fs; - let content = fs::read_to_string(path) - .with_context(|| "Failed to read config file")?; - let config: AppConfig = serde_yaml::from_str(&content) - .with_context(|| "Failed to parse config YAML")?; - Ok(config) - } - - /// Create a domain-specific Config from AppConfig and DomainConfig - pub fn create_domain_config(&self, domain: &crate::acme::domain_reader::DomainConfig, https_path: 
PathBuf) -> Config { - let mut domain_https_path = https_path.clone(); - domain_https_path.push(&domain.domain); - - let mut cert_path = domain_https_path.clone(); - cert_path.push("cert.pem"); - let mut key_path = domain_https_path.clone(); - key_path.push("key.pem"); - let static_path = domain_https_path.clone(); - - Config { - https_path: domain_https_path, - cert_path, - key_path, - static_path, - opts: ConfigOpts { - ip: self.server.ip.clone(), - port: self.server.port, - domain: domain.domain.clone(), - email: domain.email.clone().or_else(|| Some(self.acme.email.clone())), - https_dns: domain.dns, - development: self.acme.development, - dns_lookup_max_attempts: Some(self.acme.dns_lookup.max_attempts), - dns_lookup_delay_seconds: Some(self.acme.dns_lookup.delay_seconds), - storage_type: Some(self.storage.storage_type.clone()), - redis_url: self.storage.redis_url.clone(), - lock_ttl_seconds: Some(self.storage.lock_ttl_seconds), - redis_ssl: self.storage.redis_ssl.clone(), - challenge_max_ttl_seconds: Some(self.acme.challenge_max_ttl_seconds), - }, - } - } -} - - +use std::path::PathBuf; +use serde::{Deserialize, Serialize}; +use anyhow::{Result, Context}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub https_path: PathBuf, + pub cert_path: PathBuf, + pub key_path: PathBuf, + pub static_path: PathBuf, + pub opts: ConfigOpts, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConfigOpts { + pub ip: String, + pub port: u16, + pub domain: String, + pub email: Option<String>, + pub https_dns: bool, + pub development: bool, + pub dns_lookup_max_attempts: Option<u32>, + pub dns_lookup_delay_seconds: Option<u64>, + pub storage_type: Option<String>, + pub redis_url: Option<String>, + pub lock_ttl_seconds: Option<u64>, + pub redis_ssl: Option<RedisSslConfig>, + pub challenge_max_ttl_seconds: Option<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AppConfig { + pub server: ServerConfig, + pub storage: StorageConfig, + pub acme: AcmeConfig, + pub domains: crate::acme::domain_reader::DomainSourceConfig, + #[serde(default)] + pub logging: LoggingConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub ip: String, + pub port: u16, + /// Run as daemon (background process) + #[serde(default)] + pub daemon: bool, + /// PID file path (for daemon mode) + pub pid_file: Option<PathBuf>, + /// Working directory for daemon + pub working_directory: Option<PathBuf>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoggingConfig { + /// Logging output: "stdout", "syslog", or "journald" + #[serde(default = "default_log_output")] + pub output: String, + /// Log level: trace, debug, info, warn, error + #[serde(default = "default_log_level")] + pub level: String, + /// Syslog facility (for syslog output) + #[serde(default = "default_syslog_facility")] + pub syslog_facility: String, + /// Syslog identifier/tag (for syslog output) + #[serde(default = "default_syslog_identifier")] + pub syslog_identifier: String, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + output: default_log_output(), + level: default_log_level(), + syslog_facility: default_syslog_facility(), + syslog_identifier: default_syslog_identifier(), + } + } +} + +fn default_log_output() -> String { + "stdout".to_string() +} + +fn default_log_level() -> String { + "info".to_string() +} + +fn default_syslog_facility() -> String { + "daemon".to_string() +} + +fn default_syslog_identifier() -> String { + "ssl-storage".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub
struct StorageConfig { + #[serde(rename = "type")] + pub storage_type: String, + pub https_path: String, + pub redis_url: Option, + #[serde(default = "default_lock_ttl_seconds")] + pub lock_ttl_seconds: u64, + /// Redis SSL/TLS configuration + #[serde(default)] + pub redis_ssl: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RedisSslConfig { + /// Path to CA certificate file (PEM format) + pub ca_cert_path: Option, + /// Path to client certificate file (PEM format, optional) + pub client_cert_path: Option, + /// Path to client private key file (PEM format, optional) + pub client_key_path: Option, + /// Skip certificate verification (for testing with self-signed certs) + #[serde(default)] + pub insecure: bool, +} + +fn default_lock_ttl_seconds() -> u64 { + 900 // 15 minutes default +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcmeConfig { + pub email: String, + pub development: bool, + #[serde(default = "default_dns_lookup_config")] + pub dns_lookup: DnsLookupConfig, + #[serde(default = "default_retry_config")] + pub retry: RetryConfig, + /// Maximum TTL in seconds for ACME challenges (default: 3600 = 1 hour) + /// Challenges older than this will be considered expired and regenerated + #[serde(default = "default_challenge_max_ttl_seconds")] + pub challenge_max_ttl_seconds: u64, +} + +fn default_challenge_max_ttl_seconds() -> u64 { + 3600 // 1 hour - ACME challenges are typically valid for a limited time +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DnsLookupConfig { + #[serde(default = "default_max_attempts")] + pub max_attempts: u32, + #[serde(default = "default_delay_seconds")] + pub delay_seconds: u64, +} + +fn default_dns_lookup_config() -> DnsLookupConfig { + DnsLookupConfig { + max_attempts: default_max_attempts(), + delay_seconds: default_delay_seconds(), + } +} + +fn default_max_attempts() -> u32 { + 100 +} + +fn default_delay_seconds() -> u64 { + 10 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryConfig { + /// Enable periodic rechecking of failed certificates + #[serde(default = "default_enable_periodic_check")] + pub enable_periodic_check: bool, + /// Interval in seconds between periodic checks (default: 3600 = 1 hour) + #[serde(default = "default_check_interval_seconds")] + pub check_interval_seconds: u64, + /// Minimum delay in seconds before retrying a failed certificate (default: 300 = 5 minutes) + #[serde(default = "default_min_retry_delay_seconds")] + pub min_retry_delay_seconds: u64, + /// Maximum delay in seconds before retrying a failed certificate (default: 86400 = 24 hours) + #[serde(default = "default_max_retry_delay_seconds")] + pub max_retry_delay_seconds: u64, + /// Maximum number of retries before giving up (0 = unlimited, default: 0) + #[serde(default = "default_max_retries")] + pub max_retries: u32, +} + +fn default_retry_config() -> RetryConfig { + RetryConfig { + enable_periodic_check: default_enable_periodic_check(), + check_interval_seconds: default_check_interval_seconds(), + min_retry_delay_seconds: default_min_retry_delay_seconds(), + max_retry_delay_seconds: default_max_retry_delay_seconds(), + max_retries: default_max_retries(), + } +} + +fn default_enable_periodic_check() -> bool { + true +} + +fn default_check_interval_seconds() -> u64 { + 3600 // 1 hour +} + +fn default_min_retry_delay_seconds() -> u64 { + 300 // 5 minutes +} + +fn default_max_retry_delay_seconds() -> u64 { + 86400 // 24 hours +} + +fn default_max_retries() -> u32 { + 0 // unlimited +} + +impl 
AppConfig { + /// Load configuration from YAML file + pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> { + use std::fs; + let content = fs::read_to_string(path) + .with_context(|| "Failed to read config file")?; + let config: AppConfig = serde_yaml::from_str(&content) + .with_context(|| "Failed to parse config YAML")?; + Ok(config) + } + + /// Create a domain-specific Config from AppConfig and DomainConfig + pub fn create_domain_config(&self, domain: &crate::acme::domain_reader::DomainConfig, https_path: PathBuf) -> Config { + let mut domain_https_path = https_path.clone(); + domain_https_path.push(&domain.domain); + + let mut cert_path = domain_https_path.clone(); + cert_path.push("cert.pem"); + let mut key_path = domain_https_path.clone(); + key_path.push("key.pem"); + let static_path = domain_https_path.clone(); + + Config { + https_path: domain_https_path, + cert_path, + key_path, + static_path, + opts: ConfigOpts { + ip: self.server.ip.clone(), + port: self.server.port, + domain: domain.domain.clone(), + email: domain.email.clone().or_else(|| Some(self.acme.email.clone())), + https_dns: domain.dns, + development: self.acme.development, + dns_lookup_max_attempts: Some(self.acme.dns_lookup.max_attempts), + dns_lookup_delay_seconds: Some(self.acme.dns_lookup.delay_seconds), + storage_type: Some(self.storage.storage_type.clone()), + redis_url: self.storage.redis_url.clone(), + lock_ttl_seconds: Some(self.storage.lock_ttl_seconds), + redis_ssl: self.storage.redis_ssl.clone(), + challenge_max_ttl_seconds: Some(self.acme.challenge_max_ttl_seconds), + }, + } + } +} + + diff --git a/src/acme/domain_reader.rs b/src/acme/domain_reader.rs index e2ece1f..868017c 100644 --- a/src/acme/domain_reader.rs +++ b/src/acme/domain_reader.rs @@ -1,620 +1,620 @@ -//! Domain reader that supports multiple sources: file, Redis, and HTTP - -use anyhow::{Context, Result}; -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; -use std::sync::Arc; -use sha2::{Sha256, Digest}; -use notify::{Watcher, RecommendedWatcher, RecursiveMode, EventKind}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DomainConfig { - pub domain: String, - pub email: Option<String>, - pub dns: bool, - pub wildcard: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DomainSourceConfig { - pub source: String, - pub file_path: Option<String>, - pub redis_key: Option<String>, - pub redis_url: Option<String>, - pub redis_ssl: Option<crate::acme::config::RedisSslConfig>, - pub http_url: Option<String>, - pub http_refresh_interval: Option<u64>, -} - -/// Domain reader trait -#[async_trait::async_trait] -pub trait DomainReader: Send + Sync { - async fn read_domains(&self) -> Result<Vec<DomainConfig>>; -} - -/// File-based domain reader with file watching and hash-based change detection -pub struct FileDomainReader { - file_path: PathBuf, - cached_domains: Arc<tokio::sync::RwLock<Option<(Vec<DomainConfig>, String)>>>, // (domains, hash) -} - -impl FileDomainReader { - pub fn new(file_path: impl Into<PathBuf>) -> Self { - let file_path = file_path.into(); - let reader = Self { - file_path: file_path.clone(), - cached_domains: Arc::new(tokio::sync::RwLock::new(None)), - }; - - // Start file watching task - let reader_clone = reader.clone_for_watching(); - tokio::spawn(async move { - reader_clone.start_watching().await; - }); - - reader - } - - /// Create a clone for the watching task - fn clone_for_watching(&self) -> FileDomainReaderWatching { - FileDomainReaderWatching { - file_path: self.file_path.clone(), - cached_domains: Arc::clone(&self.cached_domains), - } - } - - /// Calculate SHA256 hash of content - fn calculate_hash(content: &str) -> String { - let mut
hasher = Sha256::new(); - hasher.update(content.as_bytes()); - format!("{:x}", hasher.finalize()) - } - - /// Fetch domains from file - async fn fetch_domains(&self) -> Result<(Vec, String)> { - let content = tokio::fs::read_to_string(&self.file_path) - .await - .with_context(|| format!("Failed to read domains file: {:?}", self.file_path))?; - - let hash = Self::calculate_hash(&content); - - let domains: Vec = serde_json::from_str(&content) - .with_context(|| format!("Failed to parse domains JSON: {:?}", self.file_path))?; - - Ok((domains, hash)) - } -} - -/// Internal struct for file watching task -struct FileDomainReaderWatching { - file_path: PathBuf, - cached_domains: Arc, String)>>>, -} - -impl FileDomainReaderWatching { - /// Start watching the file for changes - async fn start_watching(&self) { - // Initial load - if let Err(e) = self.check_and_update().await { - tracing::warn!("Failed to load domains file initially: {}", e); - } - - // Create watcher with std::sync::mpsc (required by notify) - let (tx, rx) = std::sync::mpsc::channel(); - - let mut watcher: RecommendedWatcher = match Watcher::new( - tx, - notify::Config::default() - .with_poll_interval(std::time::Duration::from_secs(1)) - .with_compare_contents(true), - ) { - Ok(w) => w, - Err(e) => { - tracing::error!("Failed to create file watcher: {}", e); - return; - } - }; - - // Watch the parent directory to catch file renames/moves - if let Some(parent) = self.file_path.parent() { - if let Err(e) = watcher.watch(parent, RecursiveMode::NonRecursive) { - tracing::warn!("Failed to watch directory {:?}: {}", parent, e); - } - } - - // Also watch the file directly - if let Err(e) = watcher.watch(&self.file_path, RecursiveMode::NonRecursive) { - tracing::warn!("Failed to watch file {:?}: {}", self.file_path, e); - } - - tracing::info!("Watching domains file: {:?}", self.file_path); - - // Process file events (bridge from sync channel to async) - let file_path = self.file_path.clone(); - let cached_domains = Arc::clone(&self.cached_domains); - - tokio::task::spawn_blocking(move || { - while let Ok(res) = rx.recv() { - match res { - Ok(event) => { - // Check if the event is for our file - if event.paths.iter().any(|p| p == &file_path) { - match event.kind { - EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) => { - // Use a blocking task to handle the async update - let file_path_clone = file_path.clone(); - let cached_domains_clone = Arc::clone(&cached_domains); - - tokio::spawn(async move { - // Small delay to ensure file write is complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let content = match tokio::fs::read_to_string(&file_path_clone).await { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to read domains file {:?}: {}", file_path_clone, e); - return; - } - }; - - let new_hash = FileDomainReader::calculate_hash(&content); - - // Check if hash changed - { - let cache = cached_domains_clone.read().await; - if let Some((_, old_hash)) = cache.as_ref() { - if *old_hash == new_hash { - return; // No change - } - } - } - - // Parse and update cache - let domains: Vec = match serde_json::from_str(&content) { - Ok(d) => d, - Err(e) => { - tracing::warn!("Failed to parse domains JSON from file {:?}: {}", file_path_clone, e); - return; - } - }; - - { - let mut cache = cached_domains_clone.write().await; - *cache = Some((domains, new_hash)); - } - - tracing::info!("Domains file changed (hash updated), cache refreshed"); - }); - } - _ => {} - } - } - } - Err(e) => { - 
tracing::warn!("File watcher error: {}", e); - } - } - } - }); - } - - /// Check file and update cache if content changed - async fn check_and_update(&self) -> Result { - let content = match tokio::fs::read_to_string(&self.file_path).await { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to read domains file {:?}: {}", self.file_path, e); - return Ok(false); - } - }; - - let new_hash = FileDomainReader::calculate_hash(&content); - - // Check if hash changed - { - let cache = self.cached_domains.read().await; - if let Some((_, old_hash)) = cache.as_ref() { - if *old_hash == new_hash { - return Ok(false); // No change - } - } - } - - // Parse and update cache - let domains: Vec = match serde_json::from_str(&content) { - Ok(d) => d, - Err(e) => { - tracing::warn!("Failed to parse domains JSON from file {:?}: {}", self.file_path, e); - return Ok(false); - } - }; - - { - let mut cache = self.cached_domains.write().await; - *cache = Some((domains, new_hash)); - } - - Ok(true) // Changed - } -} - -#[async_trait::async_trait] -impl DomainReader for FileDomainReader { - async fn read_domains(&self) -> Result> { - // First, try to get from cache - { - let cache = self.cached_domains.read().await; - if let Some((domains, _)) = cache.as_ref() { - return Ok(domains.clone()); - } - } - - // Cache is empty, fetch from file - let (domains, hash) = self.fetch_domains().await?; - - // Update cache - { - let mut cache = self.cached_domains.write().await; - *cache = Some((domains.clone(), hash)); - } - - Ok(domains) - } -} - -/// Redis-based domain reader with polling and hash-based change detection -pub struct RedisDomainReader { - redis_key: String, - redis_url: String, - redis_ssl: Option, - cached_domains: Arc, String)>>>, // (domains, hash) -} - -impl RedisDomainReader { - pub fn new(redis_key: String, redis_url: Option, redis_ssl: Option) -> Self { - let reader = Self { - redis_key: redis_key.clone(), - redis_url: redis_url - .clone() - .or_else(|| std::env::var("REDIS_URL").ok()) - .unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()), - redis_ssl: redis_ssl.clone(), - cached_domains: Arc::new(tokio::sync::RwLock::new(None)), - }; - - // Start background polling task - let reader_clone = reader.clone_for_polling(); - tokio::spawn(async move { - reader_clone.start_polling().await; - }); - - reader - } - - /// Create a clone for the polling task (only the necessary fields) - fn clone_for_polling(&self) -> RedisDomainReaderPolling { - RedisDomainReaderPolling { - redis_key: self.redis_key.clone(), - redis_url: self.redis_url.clone(), - redis_ssl: self.redis_ssl.clone(), - cached_domains: Arc::clone(&self.cached_domains), - } - } - - /// Calculate SHA256 hash of content - fn calculate_hash(content: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(content.as_bytes()); - format!("{:x}", hasher.finalize()) - } - - /// Create Redis client with optional SSL configuration - fn create_redis_client(&self) -> Result { - if let Some(ssl_config) = &self.redis_ssl { - Self::create_client_with_ssl(&self.redis_url, ssl_config) - } else { - redis::Client::open(self.redis_url.as_str()) - .with_context(|| format!("Failed to connect to Redis at {}", self.redis_url)) - } - } - - /// Create Redis client with custom SSL/TLS configuration (static method for use in polling) - pub(crate) fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { - use native_tls::{Certificate, Identity, TlsConnector}; - - // Build TLS connector with custom certificates - 
let mut tls_builder = TlsConnector::builder(); - - // Load CA certificate if provided - if let Some(ca_cert_path) = &ssl_config.ca_cert_path { - let ca_cert_data = std::fs::read(ca_cert_path) - .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; - let ca_cert = Certificate::from_pem(&ca_cert_data) - .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; - tls_builder.add_root_certificate(ca_cert); - tracing::info!("Loaded CA certificate from {}", ca_cert_path); - } - - // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { - let client_cert_data = std::fs::read(client_cert_path) - .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; - let client_key_data = std::fs::read(client_key_path) - .with_context(|| format!("Failed to read client key from {}", client_key_path))?; - - // Try to create identity from PEM format (cert + key) - let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) - .or_else(|_| { - // Try PEM format if PKCS#8 fails - Identity::from_pkcs12(&client_cert_data, "") - }) - .or_else(|_| { - // Try loading as separate PEM files - // Combine cert and key into a single PEM - let mut combined = client_cert_data.clone(); - combined.extend_from_slice(b"\n"); - combined.extend_from_slice(&client_key_data); - Identity::from_pkcs12(&combined, "") - }) - .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; - tls_builder.identity(identity); - tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); - } - - // Configure certificate verification - if ssl_config.insecure { - tls_builder.danger_accept_invalid_certs(true); - tls_builder.danger_accept_invalid_hostnames(true); - tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); - } - - let _tls_connector = tls_builder.build() - .with_context(|| "Failed to build TLS connector")?; - - // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, - // but doesn't expose a way to pass a custom TlsConnector. However, when using - // rediss:// URLs, it will use the system trust store. For custom CA certificates, - // we need to add them to the system trust store or use a workaround. 
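
As the comment above notes, the redis crate (with the tokio-native-tls-comp feature) cannot take a custom TlsConnector, so TLS behaviour is driven by the URL instead. A minimal sketch of that URL-level workaround, assuming a recent redis-rs that honours the rediss:// scheme and the #insecure fragment; open_tls_client and the host/port are hypothetical:

    fn open_tls_client(host: &str, insecure: bool) -> redis::RedisResult<redis::Client> {
        // TLS comes from the rediss:// scheme; the `#insecure` fragment turns off
        // certificate verification, mirroring the `insecure` flag handled above.
        let url = if insecure {
            format!("rediss://{}:6380/#insecure", host)
        } else {
            format!("rediss://{}:6380/", host)
        };
        redis::Client::open(url)
    }
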
- - let client = redis::Client::open(redis_url) - .with_context(|| format!("Failed to create Redis client with SSL config"))?; - - Ok(client) - } - - /// Fetch domains from Redis - async fn fetch_domains(&self) -> Result<(Vec, String)> { - use redis::AsyncCommands; - - let client = self.create_redis_client()?; - - use redis::aio::ConnectionManager; - let mut conn = ConnectionManager::new(client) - .await - .with_context(|| "Failed to get Redis connection")?; - - let content: String = conn.get(&self.redis_key).await - .with_context(|| format!("Failed to read domains from Redis key: {}", self.redis_key))?; - - let hash = Self::calculate_hash(&content); - - let domains: Vec = serde_json::from_str(&content) - .with_context(|| format!("Failed to parse domains JSON from Redis"))?; - - Ok((domains, hash)) - } -} - -/// Internal struct for polling task (avoids circular references) -struct RedisDomainReaderPolling { - redis_key: String, - redis_url: String, - redis_ssl: Option, - cached_domains: Arc, String)>>>, -} - -impl RedisDomainReaderPolling { - /// Create Redis client with optional SSL configuration - fn create_redis_client(&self) -> Result { - if let Some(ssl_config) = &self.redis_ssl { - RedisDomainReader::create_client_with_ssl(&self.redis_url, ssl_config) - } else { - redis::Client::open(self.redis_url.as_str()) - .with_context(|| format!("Failed to connect to Redis at {}", self.redis_url)) - } - } - - /// Start polling Redis every 5 seconds - async fn start_polling(&self) { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(5)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - loop { - interval.tick().await; - - match self.check_and_update().await { - Ok(changed) => { - if changed { - tracing::info!("Redis domains changed (hash updated), cache refreshed"); - } - } - Err(e) => { - tracing::warn!("Failed to check Redis for domain changes: {}", e); - } - } - } - } - - /// Check Redis and update cache if content changed - async fn check_and_update(&self) -> Result { - use redis::AsyncCommands; - - let client = match self.create_redis_client() { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to create Redis client: {}", e); - return Ok(false); - } - }; - - use redis::aio::ConnectionManager; - let mut conn = match ConnectionManager::new(client).await { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to get Redis connection: {}", e); - return Ok(false); - } - }; - - let content: String = match conn.get(&self.redis_key).await { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to read from Redis key {}: {}", self.redis_key, e); - return Ok(false); - } - }; - - let new_hash = RedisDomainReader::calculate_hash(&content); - - // Check if hash changed - { - let cache = self.cached_domains.read().await; - if let Some((_, old_hash)) = cache.as_ref() { - if *old_hash == new_hash { - return Ok(false); // No change - } - } - } - - // Parse and update cache - let domains: Vec = match serde_json::from_str(&content) { - Ok(d) => d, - Err(e) => { - tracing::warn!("Failed to parse domains JSON from Redis: {}", e); - return Ok(false); - } - }; - - { - let mut cache = self.cached_domains.write().await; - *cache = Some((domains, new_hash)); - } - - Ok(true) // Changed - } -} - -#[async_trait::async_trait] -impl DomainReader for RedisDomainReader { - async fn read_domains(&self) -> Result> { - // First, try to get from cache - { - let cache = self.cached_domains.read().await; - if let Some((domains, _)) = cache.as_ref() { - return 
Ok(domains.clone()); - } - } - - // Cache is empty, fetch from Redis - let (domains, hash) = self.fetch_domains().await?; - - // Update cache - { - let mut cache = self.cached_domains.write().await; - *cache = Some((domains.clone(), hash)); - } - - Ok(domains) - } -} - -/// HTTP-based domain reader -pub struct HttpDomainReader { - url: String, - refresh_interval: u64, - cached_domains: tokio::sync::RwLock, chrono::DateTime)>>, -} - -impl HttpDomainReader { - pub fn new(url: String, refresh_interval: u64) -> Self { - Self { - url, - refresh_interval, - cached_domains: tokio::sync::RwLock::new(None), - } - } - - async fn fetch_domains(&self) -> Result> { - let response = reqwest::get(&self.url).await - .with_context(|| format!("Failed to fetch domains from {}", self.url))?; - - let content = response.text().await - .with_context(|| format!("Failed to read response from {}", self.url))?; - - let domains: Vec = serde_json::from_str(&content) - .with_context(|| format!("Failed to parse domains JSON from HTTP response"))?; - - Ok(domains) - } -} - -#[async_trait::async_trait] -impl DomainReader for HttpDomainReader { - async fn read_domains(&self) -> Result> { - let now = chrono::Utc::now(); - - // Check cache - { - let cache = self.cached_domains.read().await; - if let Some((domains, cached_at)) = cache.as_ref() { - let age = now - *cached_at; - if age.num_seconds() < self.refresh_interval as i64 { - return Ok(domains.clone()); - } - } - } - - // Fetch fresh data - let domains = self.fetch_domains().await?; - - // Update cache - { - let mut cache = self.cached_domains.write().await; - *cache = Some((domains.clone(), now)); - } - - Ok(domains) - } -} - -/// Factory for creating domain readers -pub struct DomainReaderFactory; - -impl DomainReaderFactory { - pub fn create(config: &DomainSourceConfig) -> Result> { - match config.source.as_str() { - "file" => { - let file_path = config.file_path.as_ref() - .ok_or_else(|| anyhow::anyhow!("file_path is required for file source"))?; - Ok(Box::new(FileDomainReader::new(file_path))) - } - "redis" => { - let redis_key = config.redis_key.as_ref() - .ok_or_else(|| anyhow::anyhow!("redis_key is required for redis source"))? - .clone(); - let redis_url = config.redis_url.clone(); - let redis_ssl = config.redis_ssl.clone(); - Ok(Box::new(RedisDomainReader::new(redis_key, redis_url, redis_ssl))) - } - "http" => { - let url = config.http_url.as_ref() - .ok_or_else(|| anyhow::anyhow!("http_url is required for http source"))? - .clone(); - let refresh_interval = config.http_refresh_interval.unwrap_or(300); - Ok(Box::new(HttpDomainReader::new(url, refresh_interval))) - } - _ => Err(anyhow::anyhow!("Unknown domain source: {}", config.source)), - } - } -} - +//! 
Domain reader that supports multiple sources: file, Redis, and HTTP + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::Arc; +use sha2::{Sha256, Digest}; +use notify::{Watcher, RecommendedWatcher, RecursiveMode, EventKind}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DomainConfig { + pub domain: String, + pub email: Option, + pub dns: bool, + pub wildcard: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DomainSourceConfig { + pub source: String, + pub file_path: Option, + pub redis_key: Option, + pub redis_url: Option, + pub redis_ssl: Option, + pub http_url: Option, + pub http_refresh_interval: Option, +} + +/// Domain reader trait +#[async_trait::async_trait] +pub trait DomainReader: Send + Sync { + async fn read_domains(&self) -> Result>; +} + +/// File-based domain reader with file watching and hash-based change detection +pub struct FileDomainReader { + file_path: PathBuf, + cached_domains: Arc, String)>>>, // (domains, hash) +} + +impl FileDomainReader { + pub fn new(file_path: impl Into) -> Self { + let file_path = file_path.into(); + let reader = Self { + file_path: file_path.clone(), + cached_domains: Arc::new(tokio::sync::RwLock::new(None)), + }; + + // Start file watching task + let reader_clone = reader.clone_for_watching(); + tokio::spawn(async move { + reader_clone.start_watching().await; + }); + + reader + } + + /// Create a clone for the watching task + fn clone_for_watching(&self) -> FileDomainReaderWatching { + FileDomainReaderWatching { + file_path: self.file_path.clone(), + cached_domains: Arc::clone(&self.cached_domains), + } + } + + /// Calculate SHA256 hash of content + fn calculate_hash(content: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(content.as_bytes()); + format!("{:x}", hasher.finalize()) + } + + /// Fetch domains from file + async fn fetch_domains(&self) -> Result<(Vec, String)> { + let content = tokio::fs::read_to_string(&self.file_path) + .await + .with_context(|| format!("Failed to read domains file: {:?}", self.file_path))?; + + let hash = Self::calculate_hash(&content); + + let domains: Vec = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse domains JSON: {:?}", self.file_path))?; + + Ok((domains, hash)) + } +} + +/// Internal struct for file watching task +struct FileDomainReaderWatching { + file_path: PathBuf, + cached_domains: Arc, String)>>>, +} + +impl FileDomainReaderWatching { + /// Start watching the file for changes + async fn start_watching(&self) { + // Initial load + if let Err(e) = self.check_and_update().await { + tracing::warn!("Failed to load domains file initially: {}", e); + } + + // Create watcher with std::sync::mpsc (required by notify) + let (tx, rx) = std::sync::mpsc::channel(); + + let mut watcher: RecommendedWatcher = match Watcher::new( + tx, + notify::Config::default() + .with_poll_interval(std::time::Duration::from_secs(1)) + .with_compare_contents(true), + ) { + Ok(w) => w, + Err(e) => { + tracing::error!("Failed to create file watcher: {}", e); + return; + } + }; + + // Watch the parent directory to catch file renames/moves + if let Some(parent) = self.file_path.parent() { + if let Err(e) = watcher.watch(parent, RecursiveMode::NonRecursive) { + tracing::warn!("Failed to watch directory {:?}: {}", parent, e); + } + } + + // Also watch the file directly + if let Err(e) = watcher.watch(&self.file_path, RecursiveMode::NonRecursive) { + tracing::warn!("Failed to watch file {:?}: 
{}", self.file_path, e); + } + + tracing::info!("Watching domains file: {:?}", self.file_path); + + // Process file events (bridge from sync channel to async) + let file_path = self.file_path.clone(); + let cached_domains = Arc::clone(&self.cached_domains); + + tokio::task::spawn_blocking(move || { + while let Ok(res) = rx.recv() { + match res { + Ok(event) => { + // Check if the event is for our file + if event.paths.iter().any(|p| p == &file_path) { + match event.kind { + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) => { + // Use a blocking task to handle the async update + let file_path_clone = file_path.clone(); + let cached_domains_clone = Arc::clone(&cached_domains); + + tokio::spawn(async move { + // Small delay to ensure file write is complete + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let content = match tokio::fs::read_to_string(&file_path_clone).await { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to read domains file {:?}: {}", file_path_clone, e); + return; + } + }; + + let new_hash = FileDomainReader::calculate_hash(&content); + + // Check if hash changed + { + let cache = cached_domains_clone.read().await; + if let Some((_, old_hash)) = cache.as_ref() { + if *old_hash == new_hash { + return; // No change + } + } + } + + // Parse and update cache + let domains: Vec = match serde_json::from_str(&content) { + Ok(d) => d, + Err(e) => { + tracing::warn!("Failed to parse domains JSON from file {:?}: {}", file_path_clone, e); + return; + } + }; + + { + let mut cache = cached_domains_clone.write().await; + *cache = Some((domains, new_hash)); + } + + tracing::info!("Domains file changed (hash updated), cache refreshed"); + }); + } + _ => {} + } + } + } + Err(e) => { + tracing::warn!("File watcher error: {}", e); + } + } + } + }); + } + + /// Check file and update cache if content changed + async fn check_and_update(&self) -> Result { + let content = match tokio::fs::read_to_string(&self.file_path).await { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to read domains file {:?}: {}", self.file_path, e); + return Ok(false); + } + }; + + let new_hash = FileDomainReader::calculate_hash(&content); + + // Check if hash changed + { + let cache = self.cached_domains.read().await; + if let Some((_, old_hash)) = cache.as_ref() { + if *old_hash == new_hash { + return Ok(false); // No change + } + } + } + + // Parse and update cache + let domains: Vec = match serde_json::from_str(&content) { + Ok(d) => d, + Err(e) => { + tracing::warn!("Failed to parse domains JSON from file {:?}: {}", self.file_path, e); + return Ok(false); + } + }; + + { + let mut cache = self.cached_domains.write().await; + *cache = Some((domains, new_hash)); + } + + Ok(true) // Changed + } +} + +#[async_trait::async_trait] +impl DomainReader for FileDomainReader { + async fn read_domains(&self) -> Result> { + // First, try to get from cache + { + let cache = self.cached_domains.read().await; + if let Some((domains, _)) = cache.as_ref() { + return Ok(domains.clone()); + } + } + + // Cache is empty, fetch from file + let (domains, hash) = self.fetch_domains().await?; + + // Update cache + { + let mut cache = self.cached_domains.write().await; + *cache = Some((domains.clone(), hash)); + } + + Ok(domains) + } +} + +/// Redis-based domain reader with polling and hash-based change detection +pub struct RedisDomainReader { + redis_key: String, + redis_url: String, + redis_ssl: Option, + cached_domains: Arc, String)>>>, // (domains, hash) +} + +impl 
RedisDomainReader { + pub fn new(redis_key: String, redis_url: Option, redis_ssl: Option) -> Self { + let reader = Self { + redis_key: redis_key.clone(), + redis_url: redis_url + .clone() + .or_else(|| std::env::var("REDIS_URL").ok()) + .unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()), + redis_ssl: redis_ssl.clone(), + cached_domains: Arc::new(tokio::sync::RwLock::new(None)), + }; + + // Start background polling task + let reader_clone = reader.clone_for_polling(); + tokio::spawn(async move { + reader_clone.start_polling().await; + }); + + reader + } + + /// Create a clone for the polling task (only the necessary fields) + fn clone_for_polling(&self) -> RedisDomainReaderPolling { + RedisDomainReaderPolling { + redis_key: self.redis_key.clone(), + redis_url: self.redis_url.clone(), + redis_ssl: self.redis_ssl.clone(), + cached_domains: Arc::clone(&self.cached_domains), + } + } + + /// Calculate SHA256 hash of content + fn calculate_hash(content: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(content.as_bytes()); + format!("{:x}", hasher.finalize()) + } + + /// Create Redis client with optional SSL configuration + fn create_redis_client(&self) -> Result { + if let Some(ssl_config) = &self.redis_ssl { + Self::create_client_with_ssl(&self.redis_url, ssl_config) + } else { + redis::Client::open(self.redis_url.as_str()) + .with_context(|| format!("Failed to connect to Redis at {}", self.redis_url)) + } + } + + /// Create Redis client with custom SSL/TLS configuration (static method for use in polling) + pub(crate) fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { + use native_tls::{Certificate, Identity, TlsConnector}; + + // Build TLS connector with custom certificates + let mut tls_builder = TlsConnector::builder(); + + // Load CA certificate if provided + if let Some(ca_cert_path) = &ssl_config.ca_cert_path { + let ca_cert_data = std::fs::read(ca_cert_path) + .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; + let ca_cert = Certificate::from_pem(&ca_cert_data) + .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; + tls_builder.add_root_certificate(ca_cert); + tracing::info!("Loaded CA certificate from {}", ca_cert_path); + } + + // Load client certificate and key if provided + if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { + let client_cert_data = std::fs::read(client_cert_path) + .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; + let client_key_data = std::fs::read(client_key_path) + .with_context(|| format!("Failed to read client key from {}", client_key_path))?; + + // Try to create identity from PEM format (cert + key) + let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) + .or_else(|_| { + // Try PEM format if PKCS#8 fails + Identity::from_pkcs12(&client_cert_data, "") + }) + .or_else(|_| { + // Try loading as separate PEM files + // Combine cert and key into a single PEM + let mut combined = client_cert_data.clone(); + combined.extend_from_slice(b"\n"); + combined.extend_from_slice(&client_key_data); + Identity::from_pkcs12(&combined, "") + }) + .with_context(|| format!("Failed to parse client certificate/key from {} and {}. 
Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; + tls_builder.identity(identity); + tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); + } + + // Configure certificate verification + if ssl_config.insecure { + tls_builder.danger_accept_invalid_certs(true); + tls_builder.danger_accept_invalid_hostnames(true); + tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); + } + + let _tls_connector = tls_builder.build() + .with_context(|| "Failed to build TLS connector")?; + + // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, + // but doesn't expose a way to pass a custom TlsConnector. However, when using + // rediss:// URLs, it will use the system trust store. For custom CA certificates, + // we need to add them to the system trust store or use a workaround. + + let client = redis::Client::open(redis_url) + .with_context(|| format!("Failed to create Redis client with SSL config"))?; + + Ok(client) + } + + /// Fetch domains from Redis + async fn fetch_domains(&self) -> Result<(Vec, String)> { + use redis::AsyncCommands; + + let client = self.create_redis_client()?; + + use redis::aio::ConnectionManager; + let mut conn = ConnectionManager::new(client) + .await + .with_context(|| "Failed to get Redis connection")?; + + let content: String = conn.get(&self.redis_key).await + .with_context(|| format!("Failed to read domains from Redis key: {}", self.redis_key))?; + + let hash = Self::calculate_hash(&content); + + let domains: Vec = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse domains JSON from Redis"))?; + + Ok((domains, hash)) + } +} + +/// Internal struct for polling task (avoids circular references) +struct RedisDomainReaderPolling { + redis_key: String, + redis_url: String, + redis_ssl: Option, + cached_domains: Arc, String)>>>, +} + +impl RedisDomainReaderPolling { + /// Create Redis client with optional SSL configuration + fn create_redis_client(&self) -> Result { + if let Some(ssl_config) = &self.redis_ssl { + RedisDomainReader::create_client_with_ssl(&self.redis_url, ssl_config) + } else { + redis::Client::open(self.redis_url.as_str()) + .with_context(|| format!("Failed to connect to Redis at {}", self.redis_url)) + } + } + + /// Start polling Redis every 5 seconds + async fn start_polling(&self) { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(5)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + interval.tick().await; + + match self.check_and_update().await { + Ok(changed) => { + if changed { + tracing::info!("Redis domains changed (hash updated), cache refreshed"); + } + } + Err(e) => { + tracing::warn!("Failed to check Redis for domain changes: {}", e); + } + } + } + } + + /// Check Redis and update cache if content changed + async fn check_and_update(&self) -> Result { + use redis::AsyncCommands; + + let client = match self.create_redis_client() { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to create Redis client: {}", e); + return Ok(false); + } + }; + + use redis::aio::ConnectionManager; + let mut conn = match ConnectionManager::new(client).await { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to get Redis connection: {}", e); + return Ok(false); + } + }; + + let content: String = match conn.get(&self.redis_key).await { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to read from Redis key {}: {}", self.redis_key, e); + 
return Ok(false); + } + }; + + let new_hash = RedisDomainReader::calculate_hash(&content); + + // Check if hash changed + { + let cache = self.cached_domains.read().await; + if let Some((_, old_hash)) = cache.as_ref() { + if *old_hash == new_hash { + return Ok(false); // No change + } + } + } + + // Parse and update cache + let domains: Vec = match serde_json::from_str(&content) { + Ok(d) => d, + Err(e) => { + tracing::warn!("Failed to parse domains JSON from Redis: {}", e); + return Ok(false); + } + }; + + { + let mut cache = self.cached_domains.write().await; + *cache = Some((domains, new_hash)); + } + + Ok(true) // Changed + } +} + +#[async_trait::async_trait] +impl DomainReader for RedisDomainReader { + async fn read_domains(&self) -> Result> { + // First, try to get from cache + { + let cache = self.cached_domains.read().await; + if let Some((domains, _)) = cache.as_ref() { + return Ok(domains.clone()); + } + } + + // Cache is empty, fetch from Redis + let (domains, hash) = self.fetch_domains().await?; + + // Update cache + { + let mut cache = self.cached_domains.write().await; + *cache = Some((domains.clone(), hash)); + } + + Ok(domains) + } +} + +/// HTTP-based domain reader +pub struct HttpDomainReader { + url: String, + refresh_interval: u64, + cached_domains: tokio::sync::RwLock, chrono::DateTime)>>, +} + +impl HttpDomainReader { + pub fn new(url: String, refresh_interval: u64) -> Self { + Self { + url, + refresh_interval, + cached_domains: tokio::sync::RwLock::new(None), + } + } + + async fn fetch_domains(&self) -> Result> { + let response = reqwest::get(&self.url).await + .with_context(|| format!("Failed to fetch domains from {}", self.url))?; + + let content = response.text().await + .with_context(|| format!("Failed to read response from {}", self.url))?; + + let domains: Vec = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse domains JSON from HTTP response"))?; + + Ok(domains) + } +} + +#[async_trait::async_trait] +impl DomainReader for HttpDomainReader { + async fn read_domains(&self) -> Result> { + let now = chrono::Utc::now(); + + // Check cache + { + let cache = self.cached_domains.read().await; + if let Some((domains, cached_at)) = cache.as_ref() { + let age = now - *cached_at; + if age.num_seconds() < self.refresh_interval as i64 { + return Ok(domains.clone()); + } + } + } + + // Fetch fresh data + let domains = self.fetch_domains().await?; + + // Update cache + { + let mut cache = self.cached_domains.write().await; + *cache = Some((domains.clone(), now)); + } + + Ok(domains) + } +} + +/// Factory for creating domain readers +pub struct DomainReaderFactory; + +impl DomainReaderFactory { + pub fn create(config: &DomainSourceConfig) -> Result> { + match config.source.as_str() { + "file" => { + let file_path = config.file_path.as_ref() + .ok_or_else(|| anyhow::anyhow!("file_path is required for file source"))?; + Ok(Box::new(FileDomainReader::new(file_path))) + } + "redis" => { + let redis_key = config.redis_key.as_ref() + .ok_or_else(|| anyhow::anyhow!("redis_key is required for redis source"))? + .clone(); + let redis_url = config.redis_url.clone(); + let redis_ssl = config.redis_ssl.clone(); + Ok(Box::new(RedisDomainReader::new(redis_key, redis_url, redis_ssl))) + } + "http" => { + let url = config.http_url.as_ref() + .ok_or_else(|| anyhow::anyhow!("http_url is required for http source"))? 
+ .clone(); + let refresh_interval = config.http_refresh_interval.unwrap_or(300); + Ok(Box::new(HttpDomainReader::new(url, refresh_interval))) + } + _ => Err(anyhow::anyhow!("Unknown domain source: {}", config.source)), + } + } +} + diff --git a/src/acme/embedded.rs b/src/acme/embedded.rs index 040133d..689e914 100644 --- a/src/acme/embedded.rs +++ b/src/acme/embedded.rs @@ -1,418 +1,418 @@ -//! Embedded ACME server that integrates with the main synapse application -//! Reads domains from upstreams.yaml and manages certificates - -use crate::acme::domain_reader::{DomainConfig, DomainReader}; -use crate::acme::{request_cert, should_renew_certs_check, StorageFactory}; -use crate::acme::upstreams_reader::UpstreamsDomainReader; -use anyhow::{Context, Result}; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{info, warn}; -use actix_web::{App, HttpServer, HttpResponse, web, Responder}; -use serde::Serialize; - -#[derive(Debug, Clone, Serialize)] -pub struct EmbeddedAcmeConfig { - /// Port for ACME server (e.g., 9180) - pub port: u16, - /// IP address to bind (default: 127.0.0.1) - #[serde(default = "default_bind_ip")] - pub bind_ip: String, - /// Path to upstreams.yaml file - pub upstreams_path: PathBuf, - /// Email for ACME account - pub email: String, - /// Storage path for certificates - pub storage_path: PathBuf, - /// Storage type: "file" or "redis" (optional, defaults based on redis_url) - pub storage_type: Option, - /// Use development/staging ACME server - #[serde(default)] - pub development: bool, - /// Redis URL for storage (optional) - pub redis_url: Option, - /// Redis SSL config (optional) - pub redis_ssl: Option, -} - -pub struct EmbeddedAcmeServer { - config: EmbeddedAcmeConfig, - domain_reader: Arc>>>, -} - -impl EmbeddedAcmeServer { - pub fn new(config: EmbeddedAcmeConfig) -> Self { - Self { - config, - domain_reader: Arc::new(RwLock::new(None)), - } - } - - /// Initialize the domain reader from upstreams - pub async fn init_domain_reader(&self) -> Result<()> { - let reader: Arc = Arc::new( - UpstreamsDomainReader::new( - self.config.upstreams_path.clone(), - Some(self.config.email.clone()), - ) - ); - - let mut domain_reader = self.domain_reader.write().await; - *domain_reader = Some(reader); - - Ok(()) - } - - /// Start the embedded ACME HTTP server - pub async fn start_server(&self) -> Result<()> { - let address = format!("{}:{}", self.config.bind_ip, self.config.port); - info!("Starting embedded ACME server at {}", address); - - // Ensure challenge directory exists - let mut challenge_path = self.config.storage_path.clone(); - challenge_path.push("well-known"); - challenge_path.push("acme-challenge"); - tokio::fs::create_dir_all(&challenge_path).await - .with_context(|| format!("Failed to create challenge directory: {:?}", challenge_path))?; - - let challenge_path_clone = challenge_path.clone(); - let domain_reader_clone = Arc::clone(&self.domain_reader); - let config_clone = self.config.clone(); - - let server = HttpServer::new(move || { - App::new() - .app_data(web::Data::new(config_clone.clone())) - .app_data(web::Data::new(domain_reader_clone.clone())) - .service( - // Serve ACME challenges - actix_files::Files::new("/.well-known/acme-challenge", challenge_path_clone.clone()) - .prefer_utf8(true), - ) - .route( - "/cert/expiration", - web::get().to(check_all_certs_expiration_handler), - ) - .route( - "/cert/expiration/{domain}", - web::get().to(check_cert_expiration_handler), - ) - .route( - "/cert/renew/{domain}", - 
web::post().to(renew_cert_handler), - ) - .default_service(web::route().to(|| async { - HttpResponse::NotFound().body("Not Found") - })) - }) - .bind(&address) - .with_context(|| format!("Failed to bind ACME server to {}", address))?; - - info!("Embedded ACME server started at {}", address); - server.run().await - .with_context(|| "ACME server error")?; - - Ok(()) - } - - /// Process certificates for all domains - pub async fn process_certificates(&self) -> Result<()> { - let domain_reader = self.domain_reader.read().await; - let reader = domain_reader.as_ref() - .ok_or_else(|| anyhow::anyhow!("Domain reader not initialized"))?; - - let domains = reader.read_domains().await - .context("Failed to read domains")?; - - info!("Processing {} domain(s) for certificate management", domains.len()); - - for domain_config in domains { - let domain_cfg = self.create_domain_config(&domain_config)?; - - // Check if certificate needs renewal - if should_renew_certs_check(&domain_cfg).await? { - info!("Requesting new certificate for {}...", domain_config.domain); - if let Err(e) = request_cert(&domain_cfg).await { - warn!("Failed to request certificate for {}: {}", domain_config.domain, e); - } else { - info!("Certificate obtained successfully for {}!", domain_config.domain); - } - } else { - info!("Certificate is still valid for {}", domain_config.domain); - } - } - - Ok(()) - } - - fn create_domain_config(&self, domain: &DomainConfig) -> Result { - let mut domain_https_path = self.config.storage_path.clone(); - domain_https_path.push(&domain.domain); - - let mut cert_path = domain_https_path.clone(); - cert_path.push("cert.pem"); - let mut key_path = domain_https_path.clone(); - key_path.push("key.pem"); - let static_path = domain_https_path.clone(); - - // Format domain for ACME order - // If wildcard is true, ensure domain has *. prefix for ACME order - let acme_domain = if domain.wildcard && !domain.domain.starts_with("*.") { - // Extract base domain (domain + TLD, e.g., arxignis.dev from david-proxytest2.arxignis.dev) - // Split by '.' and take the last two parts - let parts: Vec<&str> = domain.domain.split('.').collect(); - if parts.len() >= 2 { - let base = parts[parts.len() - 2..].join("."); - format!("*.{}", base) - } else { - // Fallback: just add *. 
prefix - format!("*.{}", domain.domain) - } - } else { - domain.domain.clone() - }; - - Ok(crate::acme::Config { - https_path: domain_https_path, - cert_path, - key_path, - static_path, - opts: crate::acme::ConfigOpts { - ip: self.config.bind_ip.clone(), - port: self.config.port, - domain: acme_domain, - email: Some(domain.email.clone().unwrap_or_else(|| self.config.email.clone())), - https_dns: domain.dns, - development: self.config.development, - dns_lookup_max_attempts: Some(100), - dns_lookup_delay_seconds: Some(10), - storage_type: { - // Always use Redis (storage_type option is kept for compatibility but always uses Redis) - let storage_type = Some("redis".to_string()); - tracing::info!("Domain {}: Using storage type: 'redis'", domain.domain); - storage_type - }, - redis_url: self.config.redis_url.clone(), - lock_ttl_seconds: Some(900), - redis_ssl: self.config.redis_ssl.clone(), - challenge_max_ttl_seconds: Some(3600), - }, - }) - } -} - -/// HTTP handler for checking expiration of all domains -async fn check_all_certs_expiration_handler( - config: web::Data, - domain_reader: web::Data>>>>, -) -> impl Responder { - let reader = domain_reader.read().await; - let reader_ref = match reader.as_ref() { - Some(r) => r, - None => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": "Domain reader not initialized" - })); - } - }; - - let domains = match reader_ref.read_domains().await { - Ok(d) => d, - Err(e) => { - warn!("Error reading domains: {}", e); - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to read domains: {}", e) - })); - } - }; - - let mut results = Vec::new(); - for domain_config in domains { - let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { - Ok(cfg) => cfg, - Err(e) => { - warn!("Error creating domain config for {}: {}", domain_config.domain, e); - continue; - } - }; - - let storage = match StorageFactory::create_default(&domain_cfg) { - Ok(s) => s, - Err(e) => { - warn!("Error creating storage for {}: {}", domain_config.domain, e); - continue; - } - }; - - let exists = storage.cert_exists().await; - results.push(serde_json::json!({ - "domain": domain_config.domain, - "exists": exists, - })); - } - - HttpResponse::Ok().json(results) -} - -/// HTTP handler for checking expiration of a specific domain -async fn check_cert_expiration_handler( - config: web::Data, - domain_reader: web::Data>>>>, - path: web::Path, -) -> impl Responder { - let domain = path.into_inner(); - - let reader = domain_reader.read().await; - let reader_ref = match reader.as_ref() { - Some(r) => r, - None => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": "Domain reader not initialized" - })); - } - }; - - let domains = match reader_ref.read_domains().await { - Ok(d) => d, - Err(e) => { - warn!("Error reading domains: {}", e); - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to read domains: {}", e) - })); - } - }; - - let domain_config = match domains.iter().find(|d| d.domain == domain) { - Some(d) => d.clone(), - None => { - return HttpResponse::NotFound().json(serde_json::json!({ - "error": format!("Domain {} not found", domain) - })); - } - }; - - let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { - Ok(cfg) => cfg, - Err(e) => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to create domain config: {}", e) - })); - } - }; - - 
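
Isolated from the handler plumbing, the wildcard normalization in create_domain_config above behaves like this sketch; to_acme_domain is a hypothetical name, and the example domain comes from the source comment:

    fn to_acme_domain(domain: &str, wildcard: bool) -> String {
        if wildcard && !domain.starts_with("*.") {
            // Keep only the last two labels:
            // "david-proxytest2.arxignis.dev" -> "*.arxignis.dev"
            let parts: Vec<&str> = domain.split('.').collect();
            if parts.len() >= 2 {
                return format!("*.{}", parts[parts.len() - 2..].join("."));
            }
            return format!("*.{}", domain);
        }
        domain.to_string()
    }
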
let storage = match StorageFactory::create_default(&domain_cfg) { - Ok(s) => s, - Err(e) => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to create storage: {}", e) - })); - } - }; - - let exists = storage.cert_exists().await; - HttpResponse::Ok().json(serde_json::json!({ - "domain": domain, - "exists": exists, - })) -} - -/// HTTP handler for renewing a certificate -async fn renew_cert_handler( - config: web::Data, - domain_reader: web::Data>>>>, - path: web::Path, -) -> impl Responder { - let domain = path.into_inner(); - - let reader = domain_reader.read().await; - let reader_ref = match reader.as_ref() { - Some(r) => r, - None => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": "Domain reader not initialized" - })); - } - }; - - let domains = match reader_ref.read_domains().await { - Ok(d) => d, - Err(e) => { - warn!("Error reading domains: {}", e); - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to read domains: {}", e) - })); - } - }; - - let domain_config = match domains.iter().find(|d| d.domain == domain) { - Some(d) => d.clone(), - None => { - return HttpResponse::NotFound().json(serde_json::json!({ - "error": format!("Domain {} not found", domain) - })); - } - }; - - let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { - Ok(cfg) => cfg, - Err(e) => { - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to create domain config: {}", e) - })); - } - }; - - // Spawn renewal in background - let domain_config_clone = domain_config.clone(); - tokio::spawn(async move { - if let Err(e) = request_cert(&domain_cfg).await { - warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); - } else { - info!("Certificate renewed successfully for {}!", domain_config_clone.domain); - } - }); - - HttpResponse::Ok().json(serde_json::json!({ - "message": format!("Certificate renewal started for {}", domain), - })) -} - -fn create_domain_config_for_handler( - config: &EmbeddedAcmeConfig, - domain: &DomainConfig, -) -> Result { - let mut domain_https_path = config.storage_path.clone(); - domain_https_path.push(&domain.domain); - - let mut cert_path = domain_https_path.clone(); - cert_path.push("cert.pem"); - let mut key_path = domain_https_path.clone(); - key_path.push("key.pem"); - let static_path = domain_https_path.clone(); - - Ok(crate::acme::Config { - https_path: domain_https_path, - cert_path, - key_path, - static_path, - opts: crate::acme::ConfigOpts { - ip: config.bind_ip.clone(), - port: config.port, - domain: domain.domain.clone(), - email: Some(domain.email.clone().unwrap_or_else(|| config.email.clone())), - https_dns: domain.dns, - development: config.development, - dns_lookup_max_attempts: Some(100), - dns_lookup_delay_seconds: Some(10), - storage_type: { - // Always use Redis (storage_type option is kept for compatibility but always uses Redis) - Some("redis".to_string()) - }, - redis_url: config.redis_url.clone(), - lock_ttl_seconds: Some(900), - redis_ssl: config.redis_ssl.clone(), - challenge_max_ttl_seconds: Some(3600), - }, - }) -} - +//! Embedded ACME server that integrates with the main synapse application +//! 
Reads domains from upstreams.yaml and manages certificates + +use crate::acme::domain_reader::{DomainConfig, DomainReader}; +use crate::acme::{request_cert, should_renew_certs_check, StorageFactory}; +use crate::acme::upstreams_reader::UpstreamsDomainReader; +use anyhow::{Context, Result}; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{info, warn}; +use actix_web::{App, HttpServer, HttpResponse, web, Responder}; +use serde::Serialize; + +#[derive(Debug, Clone, Serialize)] +pub struct EmbeddedAcmeConfig { + /// Port for ACME server (e.g., 9180) + pub port: u16, + /// IP address to bind (default: 127.0.0.1) + #[serde(default = "default_bind_ip")] + pub bind_ip: String, + /// Path to upstreams.yaml file + pub upstreams_path: PathBuf, + /// Email for ACME account + pub email: String, + /// Storage path for certificates + pub storage_path: PathBuf, + /// Storage type: "file" or "redis" (optional, defaults based on redis_url) + pub storage_type: Option, + /// Use development/staging ACME server + #[serde(default)] + pub development: bool, + /// Redis URL for storage (optional) + pub redis_url: Option, + /// Redis SSL config (optional) + pub redis_ssl: Option, +} + +pub struct EmbeddedAcmeServer { + config: EmbeddedAcmeConfig, + domain_reader: Arc>>>, +} + +impl EmbeddedAcmeServer { + pub fn new(config: EmbeddedAcmeConfig) -> Self { + Self { + config, + domain_reader: Arc::new(RwLock::new(None)), + } + } + + /// Initialize the domain reader from upstreams + pub async fn init_domain_reader(&self) -> Result<()> { + let reader: Arc = Arc::new( + UpstreamsDomainReader::new( + self.config.upstreams_path.clone(), + Some(self.config.email.clone()), + ) + ); + + let mut domain_reader = self.domain_reader.write().await; + *domain_reader = Some(reader); + + Ok(()) + } + + /// Start the embedded ACME HTTP server + pub async fn start_server(&self) -> Result<()> { + let address = format!("{}:{}", self.config.bind_ip, self.config.port); + info!("Starting embedded ACME server at {}", address); + + // Ensure challenge directory exists + let mut challenge_path = self.config.storage_path.clone(); + challenge_path.push("well-known"); + challenge_path.push("acme-challenge"); + tokio::fs::create_dir_all(&challenge_path).await + .with_context(|| format!("Failed to create challenge directory: {:?}", challenge_path))?; + + let challenge_path_clone = challenge_path.clone(); + let domain_reader_clone = Arc::clone(&self.domain_reader); + let config_clone = self.config.clone(); + + let server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(config_clone.clone())) + .app_data(web::Data::new(domain_reader_clone.clone())) + .service( + // Serve ACME challenges + actix_files::Files::new("/.well-known/acme-challenge", challenge_path_clone.clone()) + .prefer_utf8(true), + ) + .route( + "/cert/expiration", + web::get().to(check_all_certs_expiration_handler), + ) + .route( + "/cert/expiration/{domain}", + web::get().to(check_cert_expiration_handler), + ) + .route( + "/cert/renew/{domain}", + web::post().to(renew_cert_handler), + ) + .default_service(web::route().to(|| async { + HttpResponse::NotFound().body("Not Found") + })) + }) + .bind(&address) + .with_context(|| format!("Failed to bind ACME server to {}", address))?; + + info!("Embedded ACME server started at {}", address); + server.run().await + .with_context(|| "ACME server error")?; + + Ok(()) + } + + /// Process certificates for all domains + pub async fn process_certificates(&self) -> Result<()> { + let 
domain_reader = self.domain_reader.read().await; + let reader = domain_reader.as_ref() + .ok_or_else(|| anyhow::anyhow!("Domain reader not initialized"))?; + + let domains = reader.read_domains().await + .context("Failed to read domains")?; + + info!("Processing {} domain(s) for certificate management", domains.len()); + + for domain_config in domains { + let domain_cfg = self.create_domain_config(&domain_config)?; + + // Check if certificate needs renewal + if should_renew_certs_check(&domain_cfg).await? { + info!("Requesting new certificate for {}...", domain_config.domain); + if let Err(e) = request_cert(&domain_cfg).await { + warn!("Failed to request certificate for {}: {}", domain_config.domain, e); + } else { + info!("Certificate obtained successfully for {}!", domain_config.domain); + } + } else { + info!("Certificate is still valid for {}", domain_config.domain); + } + } + + Ok(()) + } + + fn create_domain_config(&self, domain: &DomainConfig) -> Result { + let mut domain_https_path = self.config.storage_path.clone(); + domain_https_path.push(&domain.domain); + + let mut cert_path = domain_https_path.clone(); + cert_path.push("cert.pem"); + let mut key_path = domain_https_path.clone(); + key_path.push("key.pem"); + let static_path = domain_https_path.clone(); + + // Format domain for ACME order + // If wildcard is true, ensure domain has *. prefix for ACME order + let acme_domain = if domain.wildcard && !domain.domain.starts_with("*.") { + // Extract base domain (domain + TLD, e.g., arxignis.dev from david-proxytest2.arxignis.dev) + // Split by '.' and take the last two parts + let parts: Vec<&str> = domain.domain.split('.').collect(); + if parts.len() >= 2 { + let base = parts[parts.len() - 2..].join("."); + format!("*.{}", base) + } else { + // Fallback: just add *. 
prefix + format!("*.{}", domain.domain) + } + } else { + domain.domain.clone() + }; + + Ok(crate::acme::Config { + https_path: domain_https_path, + cert_path, + key_path, + static_path, + opts: crate::acme::ConfigOpts { + ip: self.config.bind_ip.clone(), + port: self.config.port, + domain: acme_domain, + email: Some(domain.email.clone().unwrap_or_else(|| self.config.email.clone())), + https_dns: domain.dns, + development: self.config.development, + dns_lookup_max_attempts: Some(100), + dns_lookup_delay_seconds: Some(10), + storage_type: { + // Always use Redis (storage_type option is kept for compatibility but always uses Redis) + let storage_type = Some("redis".to_string()); + tracing::info!("Domain {}: Using storage type: 'redis'", domain.domain); + storage_type + }, + redis_url: self.config.redis_url.clone(), + lock_ttl_seconds: Some(900), + redis_ssl: self.config.redis_ssl.clone(), + challenge_max_ttl_seconds: Some(3600), + }, + }) + } +} + +/// HTTP handler for checking expiration of all domains +async fn check_all_certs_expiration_handler( + config: web::Data, + domain_reader: web::Data>>>>, +) -> impl Responder { + let reader = domain_reader.read().await; + let reader_ref = match reader.as_ref() { + Some(r) => r, + None => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Domain reader not initialized" + })); + } + }; + + let domains = match reader_ref.read_domains().await { + Ok(d) => d, + Err(e) => { + warn!("Error reading domains: {}", e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to read domains: {}", e) + })); + } + }; + + let mut results = Vec::new(); + for domain_config in domains { + let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { + Ok(cfg) => cfg, + Err(e) => { + warn!("Error creating domain config for {}: {}", domain_config.domain, e); + continue; + } + }; + + let storage = match StorageFactory::create_default(&domain_cfg) { + Ok(s) => s, + Err(e) => { + warn!("Error creating storage for {}: {}", domain_config.domain, e); + continue; + } + }; + + let exists = storage.cert_exists().await; + results.push(serde_json::json!({ + "domain": domain_config.domain, + "exists": exists, + })); + } + + HttpResponse::Ok().json(results) +} + +/// HTTP handler for checking expiration of a specific domain +async fn check_cert_expiration_handler( + config: web::Data, + domain_reader: web::Data>>>>, + path: web::Path, +) -> impl Responder { + let domain = path.into_inner(); + + let reader = domain_reader.read().await; + let reader_ref = match reader.as_ref() { + Some(r) => r, + None => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Domain reader not initialized" + })); + } + }; + + let domains = match reader_ref.read_domains().await { + Ok(d) => d, + Err(e) => { + warn!("Error reading domains: {}", e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to read domains: {}", e) + })); + } + }; + + let domain_config = match domains.iter().find(|d| d.domain == domain) { + Some(d) => d.clone(), + None => { + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("Domain {} not found", domain) + })); + } + }; + + let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { + Ok(cfg) => cfg, + Err(e) => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to create domain config: {}", e) + })); + } + }; + + 
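
The /cert routes registered above can be exercised with any HTTP client. A hedged sketch using reqwest (with its json feature enabled), assuming the server is bound to 127.0.0.1:9180 as in the port example; check_and_renew is a hypothetical helper:

    async fn check_and_renew(domain: &str) -> anyhow::Result<()> {
        let base = "http://127.0.0.1:9180";
        // GET /cert/expiration/{domain} returns {"domain": ..., "exists": ...}.
        let status: serde_json::Value =
            reqwest::get(format!("{}/cert/expiration/{}", base, domain)).await?.json().await?;
        if status["exists"].as_bool() != Some(true) {
            // POST /cert/renew/{domain} starts renewal in a background task on the server.
            reqwest::Client::new()
                .post(format!("{}/cert/renew/{}", base, domain))
                .send().await?;
        }
        Ok(())
    }
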
let storage = match StorageFactory::create_default(&domain_cfg) { + Ok(s) => s, + Err(e) => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to create storage: {}", e) + })); + } + }; + + let exists = storage.cert_exists().await; + HttpResponse::Ok().json(serde_json::json!({ + "domain": domain, + "exists": exists, + })) +} + +/// HTTP handler for renewing a certificate +async fn renew_cert_handler( + config: web::Data, + domain_reader: web::Data>>>>, + path: web::Path, +) -> impl Responder { + let domain = path.into_inner(); + + let reader = domain_reader.read().await; + let reader_ref = match reader.as_ref() { + Some(r) => r, + None => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Domain reader not initialized" + })); + } + }; + + let domains = match reader_ref.read_domains().await { + Ok(d) => d, + Err(e) => { + warn!("Error reading domains: {}", e); + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to read domains: {}", e) + })); + } + }; + + let domain_config = match domains.iter().find(|d| d.domain == domain) { + Some(d) => d.clone(), + None => { + return HttpResponse::NotFound().json(serde_json::json!({ + "error": format!("Domain {} not found", domain) + })); + } + }; + + let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { + Ok(cfg) => cfg, + Err(e) => { + return HttpResponse::InternalServerError().json(serde_json::json!({ + "error": format!("Failed to create domain config: {}", e) + })); + } + }; + + // Spawn renewal in background + let domain_config_clone = domain_config.clone(); + tokio::spawn(async move { + if let Err(e) = request_cert(&domain_cfg).await { + warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); + } else { + info!("Certificate renewed successfully for {}!", domain_config_clone.domain); + } + }); + + HttpResponse::Ok().json(serde_json::json!({ + "message": format!("Certificate renewal started for {}", domain), + })) +} + +fn create_domain_config_for_handler( + config: &EmbeddedAcmeConfig, + domain: &DomainConfig, +) -> Result { + let mut domain_https_path = config.storage_path.clone(); + domain_https_path.push(&domain.domain); + + let mut cert_path = domain_https_path.clone(); + cert_path.push("cert.pem"); + let mut key_path = domain_https_path.clone(); + key_path.push("key.pem"); + let static_path = domain_https_path.clone(); + + Ok(crate::acme::Config { + https_path: domain_https_path, + cert_path, + key_path, + static_path, + opts: crate::acme::ConfigOpts { + ip: config.bind_ip.clone(), + port: config.port, + domain: domain.domain.clone(), + email: Some(domain.email.clone().unwrap_or_else(|| config.email.clone())), + https_dns: domain.dns, + development: config.development, + dns_lookup_max_attempts: Some(100), + dns_lookup_delay_seconds: Some(10), + storage_type: { + // Always use Redis (storage_type option is kept for compatibility but always uses Redis) + Some("redis".to_string()) + }, + redis_url: config.redis_url.clone(), + lock_ttl_seconds: Some(900), + redis_ssl: config.redis_ssl.clone(), + challenge_max_ttl_seconds: Some(3600), + }, + }) +} + diff --git a/src/acme/errors.rs b/src/acme/errors.rs index 8f1dd14..b1d8205 100644 --- a/src/acme/errors.rs +++ b/src/acme/errors.rs @@ -1,4 +1,4 @@ -use anyhow::Result; - -pub type AtomicServerResult = Result; - +use anyhow::Result; + +pub type AtomicServerResult = Result; + diff --git a/src/acme/lib.rs b/src/acme/lib.rs index 
fd52900..0923015 100644 --- a/src/acme/lib.rs +++ b/src/acme/lib.rs @@ -1,1582 +1,1582 @@ -//! Everything required for setting up HTTPS / TLS. -//! Instantiate a server for HTTP-01 check with letsencrypt, -//! checks if certificates are not outdated, -//! persists files on disk. - -use crate::acme::{Config, AppConfig, RetryConfig, AtomicServerResult}; -use crate::acme::{DomainConfig, DomainReaderFactory}; -use crate::acme::{Storage, StorageFactory, StorageType}; - -use actix_web::{App, HttpServer, HttpResponse, web, Responder}; -use anyhow::{anyhow, Context}; -use serde::Serialize; -use std::io::BufReader; -use std::sync::Arc; -use std::sync::RwLock as StdRwLock; -use once_cell::sync::OnceCell; -use tracing::{info, warn, debug}; -use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; - -/// Global proxy_certificates path (set at startup) -static PROXY_CERTIFICATES_PATH: OnceCell>>> = OnceCell::new(); - -/// Set the proxy_certificates path (called from main.rs) -pub fn set_proxy_certificates_path(path: Option) { - let path_arc = PROXY_CERTIFICATES_PATH.get_or_init(|| { - Arc::new(StdRwLock::new(None)) - }); - if let Ok(mut path_guard) = path_arc.write() { - *path_guard = path; - } -} - -/// Get the proxy_certificates path -fn get_proxy_certificates_path() -> Option { - PROXY_CERTIFICATES_PATH.get() - .and_then(|path_arc| { - path_arc.read().ok() - .and_then(|guard| guard.clone()) - }) -} - -/// Normalize PEM certificate chain to ensure proper format -/// - Ensures newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE) -/// - Ensures file ends with newline -fn normalize_pem_chain(chain: &str) -> String { - let mut normalized = chain.to_string(); - - // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----", - "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----"); - - // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files) - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----", - "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----"); - - // Ensure file ends with newline - if !normalized.ends_with('\n') { - normalized.push('\n'); - } - - normalized -} - -/// Save certificate to proxy_certificates path in the format expected by the proxy -/// Format: {sanitized_domain}.crt and {sanitized_domain}.key -async fn save_cert_to_proxy_path( - domain: &str, - fullchain: &str, - private_key: &str, - proxy_certificates_path: &str, -) -> anyhow::Result<()> { - use std::path::Path; - use tokio::fs; - use tokio::io::AsyncWriteExt; - - // Create directory if it doesn't exist - let cert_dir = Path::new(proxy_certificates_path); - fs::create_dir_all(cert_dir).await - .with_context(|| format!("Failed to create proxy_certificates directory: {}", proxy_certificates_path))?; - - // Normalize domain name (remove wildcard prefix if present) before sanitizing - // This ensures the filename matches what the certificate worker expects - let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain); - // Sanitize domain name for filename (replace . 
with _ and * with _) - let sanitized_domain = normalized_domain.replace('.', "_").replace('*', "_"); - let cert_path = cert_dir.join(format!("{}.crt", sanitized_domain)); - let key_path = cert_dir.join(format!("{}.key", sanitized_domain)); - - // Normalize PEM format - let normalized_fullchain = normalize_pem_chain(fullchain); - let normalized_key = normalize_pem_chain(private_key); - - // Write certificate file - let mut cert_file = fs::File::create(&cert_path).await - .with_context(|| format!("Failed to create certificate file: {}", cert_path.display()))?; - cert_file.write_all(normalized_fullchain.as_bytes()).await - .with_context(|| format!("Failed to write certificate file: {}", cert_path.display()))?; - cert_file.sync_all().await - .with_context(|| format!("Failed to sync certificate file: {}", cert_path.display()))?; - - // Write key file - let mut key_file = fs::File::create(&key_path).await - .with_context(|| format!("Failed to create key file: {}", key_path.display()))?; - key_file.write_all(normalized_key.as_bytes()).await - .with_context(|| format!("Failed to write key file: {}", key_path.display()))?; - key_file.sync_all().await - .with_context(|| format!("Failed to sync key file: {}", key_path.display()))?; - - info!("Saved certificate for domain '{}' to proxy_certificates path: {} (cert: {}, key: {})", - domain, proxy_certificates_path, cert_path.display(), key_path.display()); - - Ok(()) -} - -/// Create RUSTLS server config from certificates in storage -pub fn get_https_config( - config: &Config, -) -> AtomicServerResult { - use rustls_pemfile::{certs, pkcs8_private_keys}; - use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - - // Create storage backend (file system by default) - let storage = StorageFactory::create_default(config)?; - - // Read fullchain synchronously (rustls requires sync) - // Use fullchain which includes both cert and chain - let fullchain_bytes = storage.read_fullchain_sync() - .ok_or_else(|| anyhow!("Storage backend does not support synchronous fullchain reading"))??; - - let key_bytes = storage.read_key_sync() - .ok_or_else(|| anyhow!("Storage backend does not support synchronous key reading"))??; - - let cert_file = &mut BufReader::new(std::io::Cursor::new(fullchain_bytes)); - let key_file = &mut BufReader::new(std::io::Cursor::new(key_bytes)); - - let mut cert_chain = Vec::new(); - for cert_result in certs(cert_file) { - let cert = cert_result.context("Failed to parse certificate")?; - cert_chain.push(CertificateDer::from(cert)); - } - - let mut keys: Vec = pkcs8_private_keys(key_file) - .collect::, _>>() - .context("Failed to parse private key")? - .into_iter() - .map(PrivateKeyDer::Pkcs8) - .collect(); - - if keys.is_empty() { - return Err(anyhow!("No key found. 
Consider deleting the storage directory and restart to create new keys.")); - } - - let server_config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(cert_chain, keys.remove(0)) - .context("Unable to create HTTPS config from certificates")?; - - Ok(server_config) -} - -/// Check if a failed certificate should be retried based on exponential backoff -pub async fn should_retry_failed_cert( - config: &Config, - retry_config: &RetryConfig, -) -> AtomicServerResult { - let storage = StorageFactory::create_default(config)?; - - // Check if there's a failure record - let last_failure = match storage.get_last_failure().await { - Ok(Some((timestamp, _))) => timestamp, - Ok(None) => return Ok(false), // No failure recorded - Err(e) => { - warn!("Failed to read failure record: {}", e); - return Ok(false); - } - }; - - // Check if max retries exceeded - let failure_count = storage.get_failure_count().await.unwrap_or(0); - if retry_config.max_retries > 0 && failure_count >= retry_config.max_retries { - warn!("Maximum retry count ({}) exceeded for domain {}. Skipping retry.", retry_config.max_retries, config.opts.domain); - return Ok(false); - } - - // Calculate exponential backoff delay - // Formula: min(min_retry_delay * 2^(failure_count - 1), max_retry_delay) - let base_delay = retry_config.min_retry_delay_seconds as f64; - let exponential_delay = base_delay * (2.0_f64.powi((failure_count.saturating_sub(1)) as i32)); - let delay_seconds = exponential_delay.min(retry_config.max_retry_delay_seconds as f64) as u64; - - let now = chrono::Utc::now(); - let time_since_failure = now - last_failure; - let time_since_failure_secs = time_since_failure.num_seconds() as u64; - - if time_since_failure_secs >= delay_seconds { - info!("Retry delay ({}) has passed for domain {}. Last failure was {} seconds ago. Will retry.", delay_seconds, config.opts.domain, time_since_failure_secs); - Ok(true) - } else { - let remaining = delay_seconds - time_since_failure_secs; - info!("Retry delay not yet reached for domain {}. Will retry in {} seconds.", config.opts.domain, remaining); - Ok(false) - } -} - -/// Checks if the certificates need to be renewed. -/// Will be true if there are no certs yet. 
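-/// Renewal triggers once the stored `created_at` timestamp is more than four
-/// weeks old, comfortably inside the 90-day validity window of Let's Encrypt
-/// certificates.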
-pub async fn should_renew_certs_check(config: &Config) -> AtomicServerResult { - let storage = StorageFactory::create_default(config)?; - - if !storage.cert_exists().await { - info!( - "No HTTPS certificates found, requesting new ones...", - ); - return Ok(true); - } - - // Ensure certificate hash exists (generate if missing for backward compatibility) - if let Err(e) = storage.get_certificate_hash().await { - warn!("Failed to get or generate certificate hash: {}", e); - } - - let created_at = match storage.read_created_at().await { - Ok(dt) => dt, - Err(_) => { - // If we can't read the created_at file, assume certificates need renewal - warn!("Unable to read certificate creation timestamp, assuming renewal needed"); - return Ok(true); - } - }; - - let certs_age: chrono::Duration = chrono::Utc::now() - created_at; - // Let's Encrypt certificates are valid for three months, but I think renewing earlier provides a better UX - let expired = certs_age > chrono::Duration::weeks(4); - if expired { - warn!("HTTPS Certificates expired, requesting new ones...") - }; - Ok(expired) -} - -#[derive(Debug, Serialize)] -struct CertificateExpirationInfo { - domain: String, - exists: bool, - created_at: Option, - expires_at: Option, - age_days: Option, - expires_in_days: Option, - needs_renewal: bool, - #[serde(default)] - renewing: bool, -} - -/// Get certificate expiration information for a domain -async fn get_cert_expiration_info( - app_config: &AppConfig, - domain: &str, - base_https_path: &std::path::PathBuf, -) -> anyhow::Result { - let domain_cfg = { - let domain_config = DomainConfig { - domain: domain.to_string(), - email: None, - dns: false, - wildcard: false, - }; - app_config.create_domain_config(&domain_config, base_https_path.clone()) - }; - - let storage = StorageFactory::create_default(&domain_cfg)?; - let exists = storage.cert_exists().await; - - if !exists { - return Ok(CertificateExpirationInfo { - domain: domain.to_string(), - exists: false, - created_at: None, - expires_at: None, - age_days: None, - expires_in_days: None, - needs_renewal: true, - renewing: false, - }); - } - - // Ensure certificate hash exists (generate if missing for backward compatibility) - if let Err(e) = storage.get_certificate_hash().await { - warn!("Failed to get or generate certificate hash for {}: {}", domain, e); - } - - let created_at = match storage.read_created_at().await { - Ok(dt) => dt, - Err(_) => { - return Ok(CertificateExpirationInfo { - domain: domain.to_string(), - exists: true, - created_at: None, - expires_at: None, - age_days: None, - expires_in_days: None, - needs_renewal: true, - renewing: false, - }); - } - }; - - // Let's Encrypt certificates are valid for 90 days (3 months) - let expires_at = created_at + chrono::Duration::days(90); - let now = chrono::Utc::now(); - let age = now - created_at; - let expires_in = expires_at - now; - - let needs_renewal = age > chrono::Duration::weeks(4); - - Ok(CertificateExpirationInfo { - domain: domain.to_string(), - exists: true, - created_at: Some(created_at.to_rfc3339()), - expires_at: Some(expires_at.to_rfc3339()), - age_days: Some(age.num_days()), - expires_in_days: Some(expires_in.num_days()), - needs_renewal, - renewing: false, - }) -} - -/// HTTP handler for certificate expiration check (single domain) -async fn check_cert_expiration_handler( - app_config: web::Data, - base_path: web::Data, - path: web::Path, -) -> impl Responder { - let domain = path.into_inner(); - match get_cert_expiration_info(&app_config, &domain, &base_path).await { - 
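-        // Note: the renewal kicked off below is fire-and-forget (tokio::spawn);
-        // the response only reports `renewing: true` and never blocks on the
-        // ACME round-trip.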
Ok(mut info) => { - // If certificate needs renewal, start renewal process in background - if info.needs_renewal { - // Read domains to find the domain config - let domain_reader = match DomainReaderFactory::create(&app_config.domains) { - Ok(reader) => reader, - Err(e) => { - warn!("Error creating domain reader: {}", e); - return HttpResponse::Ok().json(info); - } - }; - - if let Ok(domains) = domain_reader.read_domains().await { - if let Some(domain_config) = domains.iter().find(|d| d.domain == domain) { - let app_config_clone = app_config.clone(); - let base_path_clone = base_path.clone(); - let domain_config_clone = domain_config.clone(); - - // Spawn renewal task in background - tokio::spawn(async move { - if let Err(e) = renew_cert_if_needed(&app_config_clone, &domain_config_clone, &base_path_clone).await { - warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); - } - }); - - info.renewing = true; // Mark as renewing - } - } - } - HttpResponse::Ok().json(info) - } - Err(e) => { - warn!("Error checking certificate expiration for {}: {}", domain, e); - HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to check certificate expiration: {}", e) - })) - } - } -} - -/// Renew certificate for a domain if needed -async fn renew_cert_if_needed( - app_config: &AppConfig, - domain_config: &DomainConfig, - base_path: &std::path::PathBuf, -) -> anyhow::Result<()> { - let domain_cfg = app_config.create_domain_config(domain_config, base_path.clone()); - - if should_renew_certs_check(&domain_cfg).await? { - info!("Certificate for {} is expiring, starting renewal process...", domain_config.domain); - request_cert(&domain_cfg).await?; - info!("Certificate renewed successfully for {}!", domain_config.domain); - } - - Ok(()) -} - -/// HTTP handler for checking expiration of all domains -async fn check_all_certs_expiration_handler( - app_config: web::Data, - base_path: web::Data, -) -> impl Responder { - // Read domains from the configured source - let domain_reader = match DomainReaderFactory::create(&app_config.domains) { - Ok(reader) => reader, - Err(e) => { - warn!("Error creating domain reader: {}", e); - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to create domain reader: {}", e) - })); - } - }; - - let domains = match domain_reader.read_domains().await { - Ok(domains) => domains, - Err(e) => { - warn!("Error reading domains: {}", e); - return HttpResponse::InternalServerError().json(serde_json::json!({ - "error": format!("Failed to read domains: {}", e) - })); - } - }; - - // Check expiration for each domain and renew if needed - let mut results = Vec::new(); - for domain_config in domains.iter() { - match get_cert_expiration_info(&app_config, &domain_config.domain, &base_path).await { - Ok(mut info) => { - // If certificate needs renewal, start renewal process in background - if info.needs_renewal { - let app_config_clone = app_config.clone(); - let base_path_clone = base_path.clone(); - let domain_config_clone = domain_config.clone(); - - // Spawn renewal task in background - tokio::spawn(async move { - if let Err(e) = renew_cert_if_needed(&app_config_clone, &domain_config_clone, &base_path_clone).await { - warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); - } - }); - - info.renewing = true; // Mark as renewing - } - results.push(info); - } - Err(e) => { - warn!("Error checking certificate expiration for {}: {}", domain_config.domain, e); - // Add error 
info for this domain - results.push(CertificateExpirationInfo { - domain: domain_config.domain.clone(), - exists: false, - created_at: None, - expires_at: None, - age_days: None, - expires_in_days: None, - needs_renewal: true, - renewing: false, - }); - } - } - } - - HttpResponse::Ok().json(results) -} - -/// Check DNS TXT record for DNS-01 challenge -async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attempts: u32, delay_seconds: u64) -> bool { - use trust_dns_resolver::TokioAsyncResolver; - - // Use Google DNS as primary resolver (more reliable than system DNS) - // This ensures we're querying authoritative DNS servers - let resolver_config = ResolverConfig::google(); - - info!("Checking DNS TXT record: {} (expected value: {})", record_name, expected_value); - info!("DNS lookup settings: max_attempts={}, delay_seconds={}", max_attempts, delay_seconds); - - for attempt in 1..=max_attempts { - // Create a new resolver for each attempt to ensure no caching - let mut resolver_opts = ResolverOpts::default(); - resolver_opts.use_hosts_file = true; - resolver_opts.validate = false; // Don't validate DNSSEC to avoid issues - resolver_opts.attempts = 3; // Retry attempts per query - resolver_opts.timeout = std::time::Duration::from_secs(5); // 5 second timeout - resolver_opts.cache_size = 0; // Disable DNS cache by setting cache size to 0 - - // Create a fresh DNS resolver for each attempt to avoid any caching - let resolver = TokioAsyncResolver::tokio( - resolver_config.clone(), - resolver_opts, - ); - - match resolver.txt_lookup(record_name).await { - Ok(lookup) => { - let mut found_any = false; - let mut found_values = Vec::new(); - - // Check if any TXT record matches the expected value - for record in lookup.iter() { - for txt_data in record.iter() { - let txt_string = String::from_utf8_lossy(txt_data).trim().to_string(); - found_any = true; - found_values.push(txt_string.clone()); - - if txt_string == expected_value { - info!("DNS TXT record matches expected value on attempt {}: {}", attempt, txt_string); - return true; - } - } - } - - if found_any { - if attempt == 1 || attempt % 6 == 0 { - warn!("DNS record found but value doesn't match. 
Expected: '{}', Found: {:?}", expected_value, found_values);
-                    }
-                } else {
-                    if attempt % 6 == 0 {
-                        info!("DNS record not found yet (attempt {}/{})...", attempt, max_attempts);
-                    }
-                }
-            }
-            Err(e) => {
-                if attempt == 1 || attempt % 6 == 0 {
-                    warn!("DNS lookup error on attempt {}: {}", attempt, e);
-                }
-            }
-        }
-
-        if attempt < max_attempts {
-            tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await;
-        }
-    }
-
-    warn!("DNS TXT record not found after {} attempts", max_attempts);
-    false
-}
-
-/// Check if ACME challenge endpoint is available
-/// This verifies that the ACME server is running and accessible before requesting certificates
-/// Retries with exponential backoff to handle cases where the server is still starting
-async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> {
-    use std::time::Duration;
-
-    // Build the ACME server URL (typically 127.0.0.1:9180)
-    let acme_url = format!("http://{}:{}/.well-known/acme-challenge/test-endpoint-check", config.opts.ip, config.opts.port);
-
-    debug!("Checking if ACME challenge endpoint is available at {}", acme_url);
-
-    // Create HTTP client with timeout
-    let client = reqwest::Client::builder()
-        .timeout(Duration::from_secs(5))
-        .build()
-        .context("Failed to create HTTP client for endpoint check")?;
-
-    // Retry logic: try up to 5 times with exponential backoff
-    // This handles cases where the ACME server is still starting up
-    let max_retries = 5;
-    let mut retry_delay = Duration::from_millis(10); // Start with 10ms
-
-    for attempt in 1..=max_retries {
-        match client.get(&acme_url).send().await {
-            Ok(response) => {
-                let status = response.status();
-                if status.is_success() || status.as_u16() == 404 {
-                    if attempt > 1 {
-                        debug!("ACME challenge endpoint is now available (status: {}) after {} attempt(s)", status, attempt);
-                    } else {
-                        debug!("ACME challenge endpoint is available (status: {})", status);
-                    }
-                    return Ok(());
-                } else {
-                    return Err(anyhow::anyhow!("ACME server returned unexpected status: {}", status));
-                }
-            }
-            Err(e) => {
-                let error_msg = e.to_string();
-                let is_connection_error = error_msg.contains("Connection refused")
-                    || error_msg.contains("connect")
-                    || error_msg.contains("connection")
-                    || error_msg.contains("refused")
-                    || e.is_connect()
-                    || e.is_timeout();
-
-                if attempt < max_retries && is_connection_error {
-                    debug!("ACME server not ready yet (attempt {}/{}), retrying in {:?}...", attempt, max_retries, retry_delay);
-                    tokio::time::sleep(retry_delay).await;
-                    // Exponential backoff: 10ms, 20ms, 40ms, 80ms between attempts
-                    retry_delay = retry_delay * 2;
-                    continue;
-                }
-                // Last attempt or non-connection error - return error
-                if attempt >= max_retries {
-                    return Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts: {}", acme_url, max_retries, e));
-                } else {
-                    return Err(anyhow::anyhow!("Failed to connect to ACME server at {} (non-retryable error): {}", acme_url, e));
-                }
-            }
-        }
-    }
-
-    Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts", acme_url, max_retries))
-}
-
-/// Writes challenge file for HTTP-01 challenge
-/// The main HTTP server will serve this file - no temporary server needed
-async fn cert_init_server(
-    config: &Config,
-    challenge: &instant_acme::Challenge,
-    key_auth: &str,
-) -> AtomicServerResult<()> {
-    let storage = StorageFactory::create_default(config)?;
-    storage.write_challenge(&challenge.token.to_string(), key_auth).await?;
-
-    info!("Challenge file written. 
Main HTTP server will serve it at /.well-known/acme-challenge/{}", challenge.token); - - Ok(()) -} - -/// Sends a request to LetsEncrypt to create a certificate -pub async fn request_cert(config: &Config) -> AtomicServerResult<()> { - // Always use Redis storage (storage_type option is kept for compatibility but always uses Redis) - let storage_type = StorageType::Redis; - - if storage_type == StorageType::Redis { - // Use distributed lock for Redis storage to prevent multiple instances from processing the same domain - // Create RedisStorage directly to access lock methods - let redis_storage = crate::acme::storage::RedisStorage::new(config)?; - - // Lock TTL from config (default: 900 seconds = 15 minutes) - let lock_ttl_seconds = config.opts.lock_ttl_seconds.unwrap_or(900); - - return redis_storage.with_lock(lock_ttl_seconds, || async { - request_cert_internal(config).await - }).await; - } - - // Redis storage always uses distributed lock (above) - request_cert_internal(config).await -} - -/// Parse retry-after timestamp from rate limit error message -/// Returns the retry-after timestamp if found, None otherwise -/// Handles both timestamp formats and ISO 8601 duration formats -fn parse_retry_after(error_msg: &str) -> Option> { - use chrono::{Utc, Duration}; - - // Look for "retry after" pattern (case insensitive) - let error_msg_lower = error_msg.to_lowercase(); - if let Some(pos) = error_msg_lower.find("retry after") { - // Get the text after "retry after" from the original message (preserve case for parsing) - let after_pos = error_msg[pos + "retry after".len()..].find(|c: char| !c.is_whitespace()) - .unwrap_or(0); - let mut after_text_str = error_msg[pos + "retry after".len() + after_pos..].trim().to_string(); - - // Try to find the end of the timestamp/duration - // For timestamps like "2025-11-14 21:13:29 UTC", stop at end of line or before URL/links - // Look for common patterns that indicate end of timestamp: - // - End of string - // - Before URLs (http:// or https://) - // - Before "see" or ":" followed by URL - if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) { - after_text_str = after_text_str[..url_pos].trim().to_string(); - } - // Extract timestamp - format is typically "2025-11-14 21:13:29 UTC" followed by ": see https://..." 
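-        // e.g. "retry after 2025-11-14 21:13:29 UTC: see https://..." should
-        // reduce to "2025-11-14 21:13:29 UTC" before parsing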
- // Simplest approach: find " UTC" and take everything up to and including it - if let Some(utc_pos) = after_text_str.find(" UTC") { - // Found " UTC", extract up to and including it (this is the complete timestamp) - after_text_str = after_text_str[..utc_pos + 4].trim().to_string(); - } else { - // No " UTC" found, try to stop before URLs or "see" keyword - if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) { - after_text_str = after_text_str[..url_pos].trim().to_string(); - } - if let Some(see_pos) = after_text_str.find(" see ") { - after_text_str = after_text_str[..see_pos].trim().to_string(); - } - } - - let after_text = after_text_str.as_str(); - - // First, try to parse as timestamp (format: "2025-11-10 18:08:38 UTC") - // Try with timezone first - if let Ok(dt) = chrono::DateTime::parse_from_str(after_text, "%Y-%m-%d %H:%M:%S %Z") { - return Some(dt.with_timezone(&chrono::Utc)); - } - // Try with "UTC" as separate word (common format: "2025-11-14 21:13:29 UTC") - if after_text.ends_with(" UTC") { - let without_tz = &after_text[..after_text.len() - 4].trim(); - if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(without_tz, "%Y-%m-%d %H:%M:%S") { - return Some(chrono::DateTime::from_naive_utc_and_offset(dt, chrono::Utc)); - } - } - // Try alternative format without timezone (assume UTC) - if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(after_text, "%Y-%m-%d %H:%M:%S") { - return Some(chrono::DateTime::from_naive_utc_and_offset(dt, chrono::Utc)); - } - // Try parsing as RFC3339 format - if let Ok(dt) = after_text.parse::>() { - return Some(dt); - } - // Try ISO 8601/RFC3339 format (chrono doesn't have parse_from_rfc3339, use parse_from_str with RFC3339 format) - // RFC3339 format: "2025-11-14T21:13:29Z" or "2025-11-14T21:13:29+00:00" - if let Ok(dt) = chrono::DateTime::parse_from_str(after_text, "%+") { - return Some(dt.with_timezone(&chrono::Utc)); - } - - // Try parsing as ISO 8601 duration (e.g., "PT86225.992004616S" or "PT24H") - // This happens when the error message contains a duration instead of a timestamp - if after_text.starts_with("PT") { - // Parse ISO 8601 duration: PT[nH][nM][nS] or PT[n]S - // Handle case where duration ends with 'S' (seconds) - let duration_str = if after_text.ends_with('S') && !after_text.ends_with("MS") && !after_text.ends_with("HS") { - &after_text[2..after_text.len()-1] // Remove "PT" prefix and "S" suffix - } else { - &after_text[2..] // Just remove "PT" prefix - }; - - // Try to parse as seconds (e.g., "86225.992004616") - if let Ok(seconds) = duration_str.parse::() { - let duration = Duration::seconds(seconds as i64) + Duration::nanoseconds((seconds.fract() * 1_000_000_000.0) as i64); - return Some(Utc::now() + duration); - } - - // Try to parse hours, minutes, seconds separately - let mut total_seconds = 0.0; - let mut current_num = String::new(); - let mut current_unit = String::new(); - - for ch in duration_str.chars() { - if ch.is_ascii_digit() || ch == '.' 
{ - if !current_unit.is_empty() { - // Process previous unit - if let Ok(val) = current_num.parse::() { - match current_unit.as_str() { - "H" => total_seconds += val * 3600.0, - "M" => total_seconds += val * 60.0, - "S" => total_seconds += val, - _ => {} - } - } - current_num.clear(); - current_unit.clear(); - } - current_num.push(ch); - } else if ch.is_ascii_alphabetic() { - current_unit.push(ch); - } - } - - // Process last unit - if !current_unit.is_empty() && !current_num.is_empty() { - if let Ok(val) = current_num.parse::() { - match current_unit.as_str() { - "H" => total_seconds += val * 3600.0, - "M" => total_seconds += val * 60.0, - "S" => total_seconds += val, - _ => {} - } - } - } - - if total_seconds > 0.0 { - let duration = Duration::seconds(total_seconds as i64) + Duration::nanoseconds((total_seconds.fract() * 1_000_000_000.0) as i64); - return Some(Utc::now() + duration); - } - } - } - None -} - -/// Helper function to check if an account already exists -async fn check_account_exists( - email: &str, - lets_encrypt_url: &str, -) -> Result, anyhow::Error> { - match instant_acme::Account::builder() - .context("Failed to create account builder")? - .create( - &instant_acme::NewAccount { - contact: &[&format!("mailto:{}", email)], - terms_of_service_agreed: true, - only_return_existing: true, - }, - lets_encrypt_url.to_string(), - None, - ) - .await - { - Ok((acc, cr)) => Ok(Some((acc, cr))), - Err(e) => { - let error_msg = format!("{}", e); - // If it's a rate limit error, propagate it - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { - return Err(e.into()); - } - // Otherwise, account doesn't exist - Ok(None) - } - } -} - -/// Helper function to create a new Let's Encrypt account and save credentials -/// Handles rate limits by waiting for the retry-after time -async fn create_new_account( - storage: &Box, - email: &str, - lets_encrypt_url: &str, -) -> AtomicServerResult<(instant_acme::Account, instant_acme::AccountCredentials)> { - // First, check if account already exists - match check_account_exists(email, lets_encrypt_url).await { - Ok(Some((acc, cr))) => { - info!("Account already exists for email {}, reusing it", email); - return Ok((acc, cr)); - } - Ok(None) => { - // Account doesn't exist, proceed to create - } - Err(e) => { - // Check if it's a rate limit error - let error_msg = format!("{}", e); - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { - if let Some(retry_after) = parse_retry_after(&error_msg) { - let now = chrono::Utc::now(); - if retry_after > now { - let wait_duration = retry_after - now; - let wait_secs = wait_duration.num_seconds().max(0) as u64; - warn!("Rate limit hit. Waiting {} seconds until {} before retrying account creation", wait_secs, retry_after); - tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await; - } - } else { - // Rate limit error but couldn't parse retry-after, wait a default time - warn!("Rate limit hit but couldn't parse retry-after time. 
Waiting 3 hours (10800 seconds) before retrying"); - tokio::time::sleep(tokio::time::Duration::from_secs(10800)).await; - } - } else { - // Not a rate limit error, propagate it - return Err(e); - } - } - } - - info!("Creating new LetsEncrypt account with email {}", email); - - // Retry account creation (after waiting for rate limit if needed) - let max_retries = 3; - let mut retry_count = 0; - - loop { - match instant_acme::Account::builder() - .context("Failed to create account builder")? - .create( - &instant_acme::NewAccount { - contact: &[&format!("mailto:{}", email)], - terms_of_service_agreed: true, - only_return_existing: false, - }, - lets_encrypt_url.to_string(), - None, - ) - .await - { - Ok((account, creds)) => { - // Save credentials for future use (store as JSON value for now) - if let Ok(creds_json) = serde_json::to_string(&creds) { - if let Err(e) = storage.write_account_credentials(&creds_json).await { - warn!("Failed to save account credentials to storage: {}. Account will be recreated on next run.", e); - } else { - info!("Saved LetsEncrypt account credentials to storage"); - } - } else { - warn!("Failed to serialize account credentials. Account will be recreated on next run."); - } - return Ok((account, creds)); - } - Err(e) => { - let error_msg = format!("{}", e); - - // Check if it's a rate limit error - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { - if let Some(retry_after) = parse_retry_after(&error_msg) { - let now = chrono::Utc::now(); - if retry_after > now { - let wait_duration = retry_after - now; - let wait_secs = wait_duration.num_seconds().max(0) as u64; - warn!("Rate limit hit during account creation. Waiting {} seconds until {} before retrying", wait_secs, retry_after); - tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await; - retry_count += 1; - if retry_count < max_retries { - continue; - } - } - } else { - // Rate limit error but couldn't parse retry-after - if retry_count < max_retries { - let wait_secs = 10800; // 3 hours default - warn!("Rate limit hit but couldn't parse retry-after time. Waiting {} seconds before retrying", wait_secs); - tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs)).await; - retry_count += 1; - continue; - } - } - } - - // If we've exhausted retries or it's not a rate limit error, return the error - return Err(e).context("Failed to create account"); - } - } - } -} - -async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { - use instant_acme::OrderStatus; - - // Detect wildcard domain and automatically use DNS-01 - let is_wildcard = config.opts.domain.starts_with("*."); - let use_dns = config.opts.https_dns || is_wildcard; - - if is_wildcard && !config.opts.https_dns { - warn!("Wildcard domain detected ({}), automatically using DNS-01 challenge", config.opts.domain); - } - - let challenge_type = if use_dns { - debug!("Using DNS-01 challenge"); - instant_acme::ChallengeType::Dns01 - } else { - debug!("Using HTTP-01 challenge"); - // Check if ACME challenge endpoint is available before proceeding - if let Err(e) = check_acme_challenge_endpoint(config).await { - let error_msg = format!("ACME challenge endpoint not available for HTTP-01 challenge: {}. 
Skipping certificate request.", e); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - instant_acme::ChallengeType::Http01 - }; - - // Create a new account. This will generate a fresh ECDSA key for you. - // Alternatively, restore an account from serialized credentials by - // using `Account::from_credentials()`. - - let lets_encrypt_url = if config.opts.development { - warn!( - "Using LetsEncrypt staging server, not production. This is for testing purposes only and will not provide a working certificate." - ); - instant_acme::LetsEncrypt::Staging.url() - } else { - instant_acme::LetsEncrypt::Production.url() - }; - - let email = - config.opts.email.clone().expect( - "No email set - required for HTTPS certificate initialization with LetsEncrypt", - ); - - // Try to load existing account credentials from storage - let storage = StorageFactory::create_default(config)?; - let existing_creds = storage.read_account_credentials().await - .context("Failed to read account credentials from storage")?; - - // Try to restore account from stored credentials, but fall back to creating new account if it fails - let (account, _creds) = match existing_creds { - Some(creds_json) => { - // Try to restore account from existing credentials - debug!("Attempting to restore LetsEncrypt account from stored credentials"); - - // First try to parse and restore from stored credentials - match serde_json::from_str::(&creds_json) { - Ok(creds) => { - // Try to restore account from credentials - // Use AccountBuilder to restore from credentials - match instant_acme::Account::builder() - .context("Failed to create account builder")? - .from_credentials(creds) - .await - { - Ok(acc) => { - debug!("Successfully restored LetsEncrypt account from stored credentials"); - // Get the credentials back from the account (they're stored in the account) - // For now, we'll use the stored credentials JSON - let restored_creds = serde_json::from_str::(&creds_json) - .expect("Credentials were just parsed successfully"); - (acc, restored_creds) - } - Err(e) => { - let error_msg = format!("{}", e); - warn!("Failed to restore account from stored credentials: {}. Will check if account exists.", error_msg); - - // If restoration fails, check if account exists - match check_account_exists(&email, lets_encrypt_url).await { - Ok(Some((acc, cr))) => { - info!("Account exists but credentials were invalid. Using existing account."); - (acc, cr) - } - Ok(None) => { - warn!("Stored credentials invalid and account doesn't exist. Creating new account."); - create_new_account(&storage, &email, lets_encrypt_url).await? - } - Err(e) => { - let error_msg = format!("{}", e); - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { - warn!("Rate limit hit while checking account. Will wait and retry in create_new_account."); - create_new_account(&storage, &email, lets_encrypt_url).await? - } else { - warn!("Failed to check account existence: {}. Creating new account.", e); - create_new_account(&storage, &email, lets_encrypt_url).await? - } - } - } - } - } - } - Err(e) => { - warn!("Failed to parse stored credentials: {}. Creating new account.", e); - create_new_account(&storage, &email, lets_encrypt_url).await? 
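-                    // (unparsable credentials fall back to account creation,
-                    // the same terminal step as the failed-restore paths above)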
- } - } - } - None => { - // No stored credentials, create a new account - create_new_account(&storage, &email, lets_encrypt_url).await? - } - }; - - // Create the ACME order based on the given domain names. - // Note that this only needs an `&Account`, so the library will let you - // process multiple orders in parallel for a single account. - - // Prepare domain for ACME order - // For wildcard domains (*.example.com), we need to request *.example.com - // For non-wildcard domains with DNS-01, we request the domain as-is - let domain = config.opts.domain.clone(); - let is_wildcard_domain = domain.starts_with("*."); - - if is_wildcard_domain { - // Domain already has wildcard prefix, use as-is for ACME order - // ACME requires *.example.com format for wildcard certificates - debug!("Requesting wildcard certificate for: {}", domain); - } else if use_dns { - // Non-wildcard domain with DNS-01 challenge - use domain as-is - debug!("Requesting certificate for domain with DNS-01: {}", domain); - } else { - // HTTP-01 challenge - use domain as-is - debug!("Requesting certificate for domain with HTTP-01: {}", domain); - } - - // Check if we're still in rate limit period before attempting request - use chrono::{Utc, Duration}; - let storage = StorageFactory::create_default(config)?; - if let Ok(Some((last_failure_time, last_failure_msg))) = storage.get_last_failure().await { - // Check if the last failure was a rate limit error - if last_failure_msg.contains("rateLimited") || - last_failure_msg.contains("rate limit") || - last_failure_msg.contains("too many certificates") { - // Parse retry-after time from error message - if let Some(retry_after) = parse_retry_after(&last_failure_msg) { - let now = Utc::now(); - if now < retry_after { - let wait_duration = retry_after - now; - info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", - config.opts.domain, retry_after, wait_duration); - return Ok(()); - } else { - debug!("Rate limit period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); - } - } else { - // Log the error message for debugging - tracing::debug!("Failed to parse retry-after from error message: {}", last_failure_msg); - // Can't parse retry-after from error message - // Try to extract duration from the error message if it contains ISO 8601 duration - // Look for patterns like "PT86225S" or "PT24H" anywhere in the message - let mut found_duration = None; - for word in last_failure_msg.split_whitespace() { - if word.starts_with("PT") { - // Try to parse as ISO 8601 duration - if let Some(dt) = parse_retry_after(&format!("retry after {}", word)) { - found_duration = Some(dt); - break; - } - } - } - - if let Some(retry_after) = found_duration { - let now = Utc::now(); - if now < retry_after { - let wait_duration = retry_after - now; - info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", - config.opts.domain, retry_after, wait_duration); - return Ok(()); - } - } else { - // Can't parse retry-after at all, use exponential backoff (24 hours minimum for rate limits) - let rate_limit_cooldown = Duration::hours(24); - let now = Utc::now(); - if now - last_failure_time < rate_limit_cooldown { - let remaining = rate_limit_cooldown - (now - last_failure_time); - warn!("Rate limit error detected for domain {} (retry-after time not parseable from: '{}'). Waiting {} before retry. 
Skipping certificate request.", - config.opts.domain, last_failure_msg, remaining); - return Ok(()); - } else { - debug!("Rate limit cooldown period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); - } - } - } - } - } - - let identifier = instant_acme::Identifier::Dns(domain.clone()); - let identifiers = vec![identifier]; - let mut order = match account - .new_order(&instant_acme::NewOrder::new(&identifiers)) - .await - { - Ok(order) => order, - Err(e) => { - let error_msg = format!("Failed to create new order for domain {}: {}", config.opts.domain, e); - warn!("{}. Skipping certificate request.", error_msg); - - // If it's a rate limit error, store it with retry-after time - let is_rate_limit = error_msg.contains("rateLimited") || - error_msg.contains("rate limit") || - error_msg.contains("too many certificates"); - - if is_rate_limit { - if let Some(retry_after) = parse_retry_after(&error_msg) { - info!("Rate limit error for domain {}: will retry after {}", config.opts.domain, retry_after); - } else { - warn!("Rate limit error for domain {} but could not parse retry-after time. Will wait 24 hours before retry.", config.opts.domain); - } - } - - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - }; - - // Check if order is already ready (from a previous request) - let initial_state = order.state(); - - // Handle unexpected order status - if !matches!(initial_state.status, instant_acme::OrderStatus::Pending | instant_acme::OrderStatus::Ready) { - let error_msg = format!("Unexpected order status: {:?} for domain {}", initial_state.status, config.opts.domain); - warn!("{}. Skipping certificate request.", error_msg); - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - // If order is already Ready, skip challenge processing - let state = if matches!(initial_state.status, instant_acme::OrderStatus::Ready) { - info!("Order is already in Ready state, skipping challenge processing and proceeding to finalization"); - // Use initial_state as the final state since we're skipping challenge processing - initial_state - } else { - // Order is Pending, proceed with challenge processing - // Pick the desired challenge type and prepare the response. - let mut authorizations = order.authorizations(); - let mut challenges_set = Vec::new(); - - while let Some(result) = authorizations.next().await { - let mut authz = match result { - Ok(authz) => authz, - Err(e) => { - warn!("Failed to get authorization: {}. 
Skipping this authorization.", e); - continue; - } - }; - let domain = authz.identifier().to_string(); - - match authz.status { - instant_acme::AuthorizationStatus::Pending => {} - instant_acme::AuthorizationStatus::Valid => continue, - _ => todo!(), - } - - let mut challenge = match authz.challenge(challenge_type.clone()) { - Some(c) => c, - None => { - warn!("Domain '{}': No {:?} challenge found, skipping", domain, challenge_type); - continue; - } - }; - - let key_auth = challenge.key_authorization().as_str().to_string(); - match challenge_type { - instant_acme::ChallengeType::Http01 => { - // Check if existing challenge is expired and clean it up - let storage = StorageFactory::create_default(config)?; - let challenge_token = challenge.token.to_string(); - if let Ok(Some(_)) = storage.get_challenge_timestamp(&challenge_token).await { - // Challenge exists, check if expired - let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); - if let Ok(true) = storage.is_challenge_expired(&challenge_token, max_ttl).await { - info!("Existing challenge for token {} is expired (TTL: {}s), will be replaced", challenge_token, max_ttl); - } - } - - if let Err(e) = cert_init_server(config, &challenge, &key_auth).await { - warn!("Failed to write challenge file for HTTP-01 challenge: {}. Skipping HTTP-01 challenge.", e); - continue; - } - } - instant_acme::ChallengeType::Dns01 => { - // For DNS-01 challenge, the TXT record should be at _acme-challenge.{base_domain} - // For wildcard domains (*.example.com), use the base domain (example.com) - // For non-wildcard domains, use the domain as-is - // Use the is_wildcard flag computed earlier, or check the domain from authorization - let base_domain = if domain.starts_with("*.") { - // Domain from authorization starts with *. - strip it - domain.strip_prefix("*.").unwrap_or(&domain) - } else if is_wildcard { - // is_wildcard is true but domain doesn't start with *. 
- // This can happen if ACME returns the base domain instead of wildcard - // Use the domain as-is (it's already the base domain) - &domain - } else { - // For non-wildcard, use domain as-is - &domain - }; - let dns_record = format!("_acme-challenge.{}", base_domain); - let dns_value = challenge.key_authorization().dns_value(); - - info!("DNS-01 challenge for domain '{}' (base domain: {}, wildcard: {}):", domain, base_domain, is_wildcard); - info!(" Create DNS TXT record: {} IN TXT {}", dns_record, dns_value); - info!(" This record must be added to your DNS provider before the challenge can be validated."); - - // Check if existing DNS challenge is expired and clean it up - let storage = StorageFactory::create_default(config)?; - if let Ok(Some(_)) = storage.get_dns_challenge_timestamp(&domain).await { - // DNS challenge exists, check if expired - let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); - if let Ok(true) = storage.is_dns_challenge_expired(&domain, max_ttl).await { - info!("Existing DNS challenge for domain {} is expired (TTL: {}s), will be replaced", domain, max_ttl); - } - } - - // Save DNS challenge code to storage (Redis or file) - if let Err(e) = storage.write_dns_challenge(&domain, &dns_record, &dns_value).await { - warn!("Failed to save DNS challenge code to storage: {}", e); - } - - info!("Waiting for DNS record to propagate..."); - - // Automatically check DNS records - let max_attempts = config.opts.dns_lookup_max_attempts.unwrap_or(100); - let delay_seconds = config.opts.dns_lookup_delay_seconds.unwrap_or(10); - let dns_ready = check_dns_txt_record(&dns_record, &dns_value, max_attempts, delay_seconds).await; - - if !dns_ready { - let error_msg = format!("DNS record not found after checking for domain {}. Record: {} IN TXT {}", domain, dns_record, dns_value); - warn!("{}. Please verify the DNS record is set correctly.", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - info!("DNS record found! Proceeding with challenge validation..."); - } - instant_acme::ChallengeType::TlsAlpn01 => todo!("TLS-ALPN-01 is not supported"), - _ => { - let error_msg = format!("Unsupported challenge type: {:?}", challenge_type); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - } - - // Notify ACME server to validate - info!("Domain '{}': Notifying ACME server to validate challenge", domain); - challenge.set_ready().await - .with_context(|| format!("Failed to set challenge ready for domain {}", domain))?; - challenges_set.push(domain); - } - - if challenges_set.is_empty() { - let error_msg = format!("All domains failed challenge setup for domain {}", config.opts.domain); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - // Exponentially back off until the order becomes ready or invalid. 
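-        // At most 10 tries, with a (2 + tries)-second pause between polls:
-        // the delay backs off from 3 s up to 11 s, roughly a minute in total
-        // before giving up on the order.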
- let mut tries = 0u8; - let state = loop { - let state = match order.refresh().await { - Ok(s) => s, - Err(e) => { - if tries >= 10 { - let error_msg = format!("Order refresh failed after {} attempts: {}", tries, e); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - tries += 1; - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - continue; - } - }; - - info!("Order state: {:#?}", state); - if let OrderStatus::Ready | OrderStatus::Invalid | OrderStatus::Valid = state.status { - break state; - } - - tries += 1; - if tries >= 10 { - let error_msg = format!("Giving up: order is not ready after {} attempts for domain {}", tries, config.opts.domain); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - let delay = std::time::Duration::from_secs(2 + tries as u64); - info!("order is not ready, waiting {delay:?}"); - tokio::time::sleep(delay).await; - }; - - if state.status == OrderStatus::Invalid { - // Try to get more details about why the order is invalid - let mut error_details = Vec::new(); - if let Some(error) = &state.error { - error_details.push(format!("Order error: {:?}", error)); - } - - // Fetch authorization details from ACME server if state is None - for auth in &state.authorizations { - if let Some(auth_state) = &auth.state { - // Check authorization status for more details - match &auth_state.status { - instant_acme::AuthorizationStatus::Invalid => { - error_details.push(format!("Authorization {} is invalid", auth.url)); - } - instant_acme::AuthorizationStatus::Expired => { - error_details.push(format!("Authorization {} expired", auth.url)); - } - instant_acme::AuthorizationStatus::Revoked => { - error_details.push(format!("Authorization {} revoked", auth.url)); - } - _ => {} - } - } else { - // Authorization state is None - this means the authorization details weren't included in the order state - // We can't fetch it again because order.authorizations() was already consumed - // Log the URL so the user can check it manually - warn!("Authorization state is None for {}. This usually means the authorization failed or expired. Check the authorization URL for details.", auth.url); - error_details.push(format!("Authorization {} state unavailable (check URL for details)", auth.url)); - } - } - - let error_msg = if error_details.is_empty() { - format!("Order is invalid but no error details available. Order state: {:#?}", state) - } else { - format!("Order is invalid. 
Details: {}", error_details.join("; ")) - }; - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - state - }; - - // Check if state is invalid before proceeding to finalization - if state.status == OrderStatus::Invalid { - let error_msg = format!("Order is invalid for domain {}", config.opts.domain); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - - // If the order is ready, we can provision the certificate. - // Finalize the order - this will generate a CSR and return the private key PEM. - let private_key_pem = order.finalize().await - .context("Failed to finalize ACME order")?; - - std::thread::sleep(std::time::Duration::from_secs(1)); - let mut tries = 1u8; - - let cert_chain_pem = loop { - match order.certificate().await { - Ok(Some(cert_chain_pem)) => { - info!("Certificate ready!"); - break cert_chain_pem; - } - Ok(None) => { - if tries > 10 { - let error_msg = format!("Giving up: certificate is still not ready after {} attempts", tries); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Err(anyhow!("{}", error_msg)); - } - tries += 1; - info!("Certificate not ready yet..."); - continue; - } - Err(e) => { - let error_msg = format!("Error getting certificate: {}", e); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Err(anyhow!("{}", error_msg)); - } - } - }; - - write_certs(config, cert_chain_pem, private_key_pem).await - .context("Failed to write certificates to storage")?; - - // Clear any previous failure records since certificate was successfully generated - let storage = StorageFactory::create_default(config)?; - if let Err(clear_err) = storage.clear_failure().await { - warn!("Failed to clear failure record: {}", clear_err); - } - - info!("HTTPS TLS Cert init successful! 
Certificate written to storage."); - - Ok(()) -} - -async fn write_certs( - config: &Config, - cert_chain_pem: String, - private_key_pem: String, -) -> AtomicServerResult<()> { - // Always use Redis storage (storage_type option is kept for compatibility but always uses Redis) - info!("Creating Redis storage backend"); - let storage = StorageFactory::create_default(config)?; - info!("Storage backend created successfully"); - - info!("Writing TLS certificates to storage (certbot-style)"); - - // Parse the certificate chain to separate cert from chain - // The cert_chain_pem contains the domain cert first, followed by intermediate certs - // It's already in PEM format, so we split it by "-----BEGIN CERTIFICATE-----" - let cert_parts: Vec = cert_chain_pem - .split("-----BEGIN CERTIFICATE-----") - .filter(|s| !s.trim().is_empty()) - .map(|s| format!("-----BEGIN CERTIFICATE-----{}", s)) - .collect(); - - if cert_parts.is_empty() { - return Err(anyhow!("No certificates found in chain")); - } - - // First certificate is the domain certificate - let domain_cert_pem = cert_parts[0].trim().to_string(); - - // Remaining certificates form the chain - let chain_pem = if cert_parts.len() > 1 { - cert_parts[1..].join("\n") - } else { - String::new() - }; - - // Combine cert and chain to create fullchain - let mut fullchain = domain_cert_pem.clone(); - if !chain_pem.is_empty() { - fullchain.push_str("\n"); - fullchain.push_str(&chain_pem); - } - - info!("Writing certificate to Redis storage backend..."); - storage.write_certs( - domain_cert_pem.as_bytes(), - chain_pem.as_bytes(), - private_key_pem.as_bytes(), - ).await - .context("Failed to write certificates to storage backend")?; - info!("Certificates written successfully to Redis storage backend"); - - storage.write_created_at(chrono::Utc::now()).await - .context("Failed to write created_at timestamp")?; - - // Save certificates to proxy_certificates path - if let Some(proxy_certificates_path) = get_proxy_certificates_path() { - if let Err(e) = save_cert_to_proxy_path( - &config.opts.domain, - &fullchain, - &private_key_pem, - &proxy_certificates_path, - ).await { - warn!("Failed to save certificate to proxy_certificates path: {}", e); - } else { - info!("Certificate saved to proxy_certificates path: {}", proxy_certificates_path); - } - } else { - warn!("proxy_certificates path not configured, skipping file save"); - } - info!("Created_at timestamp written successfully"); - - Ok(()) -} - -/// Start HTTP server for ACME challenge requests -/// This server only serves ACME challenge files and keeps running indefinitely -pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> { - let address = format!("{}:{}", app_config.server.ip, app_config.server.port); - info!("Starting HTTP server for ACME challenges at {}", address); - info!("Server will only accept ACME challenge requests at /.well-known/acme-challenge/*"); - info!("Certificate expiration check endpoints:"); - info!(" - GET /cert/expiration - Check all domains"); - info!(" - GET /cert/expiration/{{domain}} - Check specific domain"); - info!("To stop the program, press Ctrl+C"); - - // Use the base storage path for serving ACME challenges - // Challenges are stored in a shared location: https_path/well-known/acme-challenge/ - let base_static_path = std::path::PathBuf::from(&app_config.storage.https_path); - - // Build the path to the well-known/acme-challenge directory - // Files are stored at: base_path/well-known/acme-challenge/{token} - let mut challenge_static_path = 
base_static_path.clone(); - challenge_static_path.push("well-known"); - challenge_static_path.push("acme-challenge"); - - // Ensure the challenge directory exists (required for actix_files::Files) - // Even when using Redis storage, challenge files are still written to filesystem for HTTP-01 - tokio::fs::create_dir_all(&challenge_static_path) - .await - .with_context(|| format!("Failed to create challenge static path directory: {:?}", challenge_static_path))?; - - let base_https_path = base_static_path.clone(); - let app_config_data = web::Data::new(app_config.clone()); - let base_path_data = web::Data::new(base_https_path); - - // Create HTTP server that only serves ACME challenge files - // The server will serve from any domain's challenge directory - let server = HttpServer::new(move || { - App::new() - .app_data(app_config_data.clone()) - .app_data(base_path_data.clone()) - .service( - // Serve ACME challenges from the challenge directory - // URL: /.well-known/acme-challenge/{token} - // File: base_path/well-known/acme-challenge/{token} - // The Files service maps the URL path to the file system path - actix_files::Files::new("/.well-known/acme-challenge", challenge_static_path.clone()) - .prefer_utf8(true), - ) - .route( - "/cert/expiration", - web::get().to(check_all_certs_expiration_handler), - ) - .route( - "/cert/expiration/{domain}", - web::get().to(check_cert_expiration_handler), - ) - // Reject all other requests with 404 - .default_service(web::route().to(|| async { - HttpResponse::NotFound().body("Not Found") - })) - }) - .bind(&address) - .with_context(|| format!("Failed to bind HTTP server to {}", address))?; - - info!("HTTP server started successfully at {}", address); - - // Keep the server running indefinitely - server.run().await - .with_context(|| "HTTP server error")?; - - Ok(()) -} +//! Everything required for setting up HTTPS / TLS. +//! Instantiate a server for HTTP-01 check with letsencrypt, +//! checks if certificates are not outdated, +//! persists files on disk. 
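+//!
+//! A minimal driver sketch (hypothetical `cfg: Config`, error handling elided),
+//! assuming the caller schedules renewal itself:
+//! ```ignore
+//! if should_renew_certs_check(&cfg).await? {
+//!     request_cert(&cfg).await?;             // HTTP-01 or DNS-01, per cfg.opts
+//! }
+//! let tls_config = get_https_config(&cfg)?;  // rustls::ServerConfig
+//! ```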
+
+use crate::acme::{Config, AppConfig, RetryConfig, AtomicServerResult};
+use crate::acme::{DomainConfig, DomainReaderFactory};
+use crate::acme::{Storage, StorageFactory, StorageType};
+
+use actix_web::{App, HttpServer, HttpResponse, web, Responder};
+use anyhow::{anyhow, Context};
+use serde::Serialize;
+use std::io::BufReader;
+use std::sync::Arc;
+use std::sync::RwLock as StdRwLock;
+use once_cell::sync::OnceCell;
+use tracing::{info, warn, debug};
+use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
+
+/// Global proxy_certificates path (set at startup)
+static PROXY_CERTIFICATES_PATH: OnceCell<Arc<StdRwLock<Option<String>>>> = OnceCell::new();
+
+/// Set the proxy_certificates path (called from main.rs)
+pub fn set_proxy_certificates_path(path: Option<String>) {
+    let path_arc = PROXY_CERTIFICATES_PATH.get_or_init(|| {
+        Arc::new(StdRwLock::new(None))
+    });
+    if let Ok(mut path_guard) = path_arc.write() {
+        *path_guard = path;
+    }
+}
+
+/// Get the proxy_certificates path
+fn get_proxy_certificates_path() -> Option<String> {
+    PROXY_CERTIFICATES_PATH.get()
+        .and_then(|path_arc| {
+            path_arc.read().ok()
+                .and_then(|guard| guard.clone())
+        })
+}
+
+/// Normalize PEM certificate chain to ensure proper format
+/// - Ensures newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE)
+/// - Ensures file ends with newline
+fn normalize_pem_chain(chain: &str) -> String {
+    let mut normalized = chain.to_string();
+
+    // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE
+    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----",
+        "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----");
+
+    // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files)
+    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----",
+        "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----");
+
+    // Ensure file ends with newline
+    if !normalized.ends_with('\n') {
+        normalized.push('\n');
+    }
+
+    normalized
+}
+
+/// Save certificate to proxy_certificates path in the format expected by the proxy
+/// Format: {sanitized_domain}.crt and {sanitized_domain}.key
+async fn save_cert_to_proxy_path(
+    domain: &str,
+    fullchain: &str,
+    private_key: &str,
+    proxy_certificates_path: &str,
+) -> anyhow::Result<()> {
+    use std::path::Path;
+    use tokio::fs;
+    use tokio::io::AsyncWriteExt;
+
+    // Create directory if it doesn't exist
+    let cert_dir = Path::new(proxy_certificates_path);
+    fs::create_dir_all(cert_dir).await
+        .with_context(|| format!("Failed to create proxy_certificates directory: {}", proxy_certificates_path))?;
+
+    // Normalize domain name (remove wildcard prefix if present) before sanitizing
+    // This ensures the filename matches what the certificate worker expects
+    let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain);
+    // Sanitize domain name for filename (replace . with _ and * with _)
+    let sanitized_domain = normalized_domain.replace('.', "_").replace('*', "_");
+    let cert_path = cert_dir.join(format!("{}.crt", sanitized_domain));
+    let key_path = cert_dir.join(format!("{}.key", sanitized_domain));
+
+    // Normalize PEM format
+    let normalized_fullchain = normalize_pem_chain(fullchain);
+    let normalized_key = normalize_pem_chain(private_key);
+
+    // Write certificate file
+    let mut cert_file = fs::File::create(&cert_path).await
+        .with_context(|| format!("Failed to create certificate file: {}", cert_path.display()))?;
+    cert_file.write_all(normalized_fullchain.as_bytes()).await
+        .with_context(|| format!("Failed to write certificate file: {}", cert_path.display()))?;
+    cert_file.sync_all().await
+        .with_context(|| format!("Failed to sync certificate file: {}", cert_path.display()))?;
+
+    // Write key file
+    let mut key_file = fs::File::create(&key_path).await
+        .with_context(|| format!("Failed to create key file: {}", key_path.display()))?;
+    key_file.write_all(normalized_key.as_bytes()).await
+        .with_context(|| format!("Failed to write key file: {}", key_path.display()))?;
+    key_file.sync_all().await
+        .with_context(|| format!("Failed to sync key file: {}", key_path.display()))?;
+
+    info!("Saved certificate for domain '{}' to proxy_certificates path: {} (cert: {}, key: {})",
+        domain, proxy_certificates_path, cert_path.display(), key_path.display());
+
+    Ok(())
+}
+
+/// Create RUSTLS server config from certificates in storage
+pub fn get_https_config(
+    config: &Config,
+) -> AtomicServerResult<rustls::ServerConfig> {
+    use rustls_pemfile::{certs, pkcs8_private_keys};
+    use rustls::pki_types::{CertificateDer, PrivateKeyDer};
+
+    // Create storage backend (file system by default)
+    let storage = StorageFactory::create_default(config)?;
+
+    // Read fullchain synchronously (rustls requires sync)
+    // Use fullchain which includes both cert and chain
+    let fullchain_bytes = storage.read_fullchain_sync()
+        .ok_or_else(|| anyhow!("Storage backend does not support synchronous fullchain reading"))??;
+
+    let key_bytes = storage.read_key_sync()
+        .ok_or_else(|| anyhow!("Storage backend does not support synchronous key reading"))??;
+
+    let cert_file = &mut BufReader::new(std::io::Cursor::new(fullchain_bytes));
+    let key_file = &mut BufReader::new(std::io::Cursor::new(key_bytes));
+
+    let mut cert_chain = Vec::new();
+    for cert_result in certs(cert_file) {
+        let cert = cert_result.context("Failed to parse certificate")?;
+        cert_chain.push(CertificateDer::from(cert));
+    }
+
+    let mut keys: Vec<PrivateKeyDer> = pkcs8_private_keys(key_file)
+        .collect::<Result<Vec<_>, _>>()
+        .context("Failed to parse private key")?
+        .into_iter()
+        .map(PrivateKeyDer::Pkcs8)
+        .collect();
+
+    if keys.is_empty() {
+        return Err(anyhow!("No key found. Consider deleting the storage directory and restarting to create new keys."));
+    }
+
+    let server_config = rustls::ServerConfig::builder()
+        .with_no_client_auth()
+        .with_single_cert(cert_chain, keys.remove(0))
+        .context("Unable to create HTTPS config from certificates")?;
+
+    Ok(server_config)
+}
+
+/// Check if a failed certificate should be retried based on exponential backoff
+pub async fn should_retry_failed_cert(
+    config: &Config,
+    retry_config: &RetryConfig,
+) -> AtomicServerResult<bool> {
+    let storage = StorageFactory::create_default(config)?;
+
+    // Check if there's a failure record
+    let last_failure = match storage.get_last_failure().await {
+        Ok(Some((timestamp, _))) => timestamp,
+        Ok(None) => return Ok(false), // No failure recorded
+        Err(e) => {
+            warn!("Failed to read failure record: {}", e);
+            return Ok(false);
+        }
+    };
+
+    // Check if max retries exceeded
+    let failure_count = storage.get_failure_count().await.unwrap_or(0);
+    if retry_config.max_retries > 0 && failure_count >= retry_config.max_retries {
+        warn!("Maximum retry count ({}) exceeded for domain {}. Skipping retry.", retry_config.max_retries, config.opts.domain);
+        return Ok(false);
+    }
+
+    // Calculate exponential backoff delay
+    // Formula: min(min_retry_delay * 2^(failure_count - 1), max_retry_delay)
+    let base_delay = retry_config.min_retry_delay_seconds as f64;
+    let exponential_delay = base_delay * (2.0_f64.powi((failure_count.saturating_sub(1)) as i32));
+    let delay_seconds = exponential_delay.min(retry_config.max_retry_delay_seconds as f64) as u64;
+
+    let now = chrono::Utc::now();
+    let time_since_failure = now - last_failure;
+    let time_since_failure_secs = time_since_failure.num_seconds() as u64;
+
+    if time_since_failure_secs >= delay_seconds {
+        info!("Retry delay ({}) has passed for domain {}. Last failure was {} seconds ago. Will retry.", delay_seconds, config.opts.domain, time_since_failure_secs);
+        Ok(true)
+    } else {
+        let remaining = delay_seconds - time_since_failure_secs;
+        info!("Retry delay not yet reached for domain {}. Will retry in {} seconds.", config.opts.domain, remaining);
+        Ok(false)
+    }
+}
+
+/// Checks if the certificates need to be renewed.
+/// Will be true if there are no certs yet.
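// Illustrative sketch (not part of the change): the backoff formula above as a
// standalone function, with hypothetical values min_retry_delay_seconds = 60
// and max_retry_delay_seconds = 3600:
//
//     fn backoff_delay_seconds(min_delay: u64, max_delay: u64, failures: u32) -> u64 {
//         let exp = (min_delay as f64) * 2.0_f64.powi(failures.saturating_sub(1) as i32);
//         exp.min(max_delay as f64) as u64
//     }
//
//     // failures 1..=8 yield 60, 120, 240, 480, 960, 1920, 3600 (capped), 3600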
+pub async fn should_renew_certs_check(config: &Config) -> AtomicServerResult<bool> {
+    let storage = StorageFactory::create_default(config)?;
+
+    if !storage.cert_exists().await {
+        info!("No HTTPS certificates found, requesting new ones...");
+        return Ok(true);
+    }
+
+    // Ensure certificate hash exists (generate if missing for backward compatibility)
+    if let Err(e) = storage.get_certificate_hash().await {
+        warn!("Failed to get or generate certificate hash: {}", e);
+    }
+
+    let created_at = match storage.read_created_at().await {
+        Ok(dt) => dt,
+        Err(_) => {
+            // If we can't read the created_at file, assume certificates need renewal
+            warn!("Unable to read certificate creation timestamp, assuming renewal needed");
+            return Ok(true);
+        }
+    };
+
+    let certs_age: chrono::Duration = chrono::Utc::now() - created_at;
+    // Let's Encrypt certificates are valid for three months; renewing earlier provides a better UX
+    let expired = certs_age > chrono::Duration::weeks(4);
+    if expired {
+        warn!("HTTPS Certificates expired, requesting new ones...")
+    };
+    Ok(expired)
+}
+
+#[derive(Debug, Serialize)]
+struct CertificateExpirationInfo {
+    domain: String,
+    exists: bool,
+    created_at: Option<String>,
+    expires_at: Option<String>,
+    age_days: Option<i64>,
+    expires_in_days: Option<i64>,
+    needs_renewal: bool,
+    #[serde(default)]
+    renewing: bool,
+}
+
+/// Get certificate expiration information for a domain
+async fn get_cert_expiration_info(
+    app_config: &AppConfig,
+    domain: &str,
+    base_https_path: &std::path::PathBuf,
+) -> anyhow::Result<CertificateExpirationInfo> {
+    let domain_cfg = {
+        let domain_config = DomainConfig {
+            domain: domain.to_string(),
+            email: None,
+            dns: false,
+            wildcard: false,
+        };
+        app_config.create_domain_config(&domain_config, base_https_path.clone())
+    };
+
+    let storage = StorageFactory::create_default(&domain_cfg)?;
+    let exists = storage.cert_exists().await;
+
+    if !exists {
+        return Ok(CertificateExpirationInfo {
+            domain: domain.to_string(),
+            exists: false,
+            created_at: None,
+            expires_at: None,
+            age_days: None,
+            expires_in_days: None,
+            needs_renewal: true,
+            renewing: false,
+        });
+    }
+
+    // Ensure certificate hash exists (generate if missing for backward compatibility)
+    if let Err(e) = storage.get_certificate_hash().await {
+        warn!("Failed to get or generate certificate hash for {}: {}", domain, e);
+    }
+
+    let created_at = match storage.read_created_at().await {
+        Ok(dt) => dt,
+        Err(_) => {
+            return Ok(CertificateExpirationInfo {
+                domain: domain.to_string(),
+                exists: true,
+                created_at: None,
+                expires_at: None,
+                age_days: None,
+                expires_in_days: None,
+                needs_renewal: true,
+                renewing: false,
+            });
+        }
+    };
+
+    // Let's Encrypt certificates are valid for 90 days (3 months)
+    let expires_at = created_at + chrono::Duration::days(90);
+    let now = chrono::Utc::now();
+    let age = now - created_at;
+    let expires_in = expires_at - now;
+
+    let needs_renewal = age > chrono::Duration::weeks(4);
+
+    Ok(CertificateExpirationInfo {
+        domain: domain.to_string(),
+        exists: true,
+        created_at: Some(created_at.to_rfc3339()),
+        expires_at: Some(expires_at.to_rfc3339()),
+        age_days: Some(age.num_days()),
+        expires_in_days: Some(expires_in.num_days()),
+        needs_renewal,
+        renewing: false,
+    })
+}
+
+/// HTTP handler for certificate expiration check (single domain)
+async fn check_cert_expiration_handler(
+    app_config: web::Data<AppConfig>,
+    base_path: web::Data<std::path::PathBuf>,
+    path: web::Path<String>,
+) -> impl Responder {
+    let domain = path.into_inner();
+    match get_cert_expiration_info(&app_config, &domain, &base_path).await {
+        Ok(mut info) => {
+            // If certificate needs renewal, start renewal process in background
+            if info.needs_renewal {
+                // Read domains to find the domain config
+                let domain_reader = match DomainReaderFactory::create(&app_config.domains) {
+                    Ok(reader) => reader,
+                    Err(e) => {
+                        warn!("Error creating domain reader: {}", e);
+                        return HttpResponse::Ok().json(info);
+                    }
+                };
+
+                if let Ok(domains) = domain_reader.read_domains().await {
+                    if let Some(domain_config) = domains.iter().find(|d| d.domain == domain) {
+                        let app_config_clone = app_config.clone();
+                        let base_path_clone = base_path.clone();
+                        let domain_config_clone = domain_config.clone();
+
+                        // Spawn renewal task in background
+                        tokio::spawn(async move {
+                            if let Err(e) = renew_cert_if_needed(&app_config_clone, &domain_config_clone, &base_path_clone).await {
+                                warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e);
+                            }
+                        });
+
+                        info.renewing = true; // Mark as renewing
+                    }
+                }
+            }
+            HttpResponse::Ok().json(info)
+        }
+        Err(e) => {
+            warn!("Error checking certificate expiration for {}: {}", domain, e);
+            HttpResponse::InternalServerError().json(serde_json::json!({
+                "error": format!("Failed to check certificate expiration: {}", e)
+            }))
+        }
+    }
+}
+
+/// Renew certificate for a domain if needed
+async fn renew_cert_if_needed(
+    app_config: &AppConfig,
+    domain_config: &DomainConfig,
+    base_path: &std::path::PathBuf,
+) -> anyhow::Result<()> {
+    let domain_cfg = app_config.create_domain_config(domain_config, base_path.clone());
+
+    if should_renew_certs_check(&domain_cfg).await? {
+        info!("Certificate for {} is expiring, starting renewal process...", domain_config.domain);
+        request_cert(&domain_cfg).await?;
+        info!("Certificate renewed successfully for {}!", domain_config.domain);
+    }
+
+    Ok(())
+}
+
+/// HTTP handler for checking expiration of all domains
+async fn check_all_certs_expiration_handler(
+    app_config: web::Data<AppConfig>,
+    base_path: web::Data<std::path::PathBuf>,
+) -> impl Responder {
+    // Read domains from the configured source
+    let domain_reader = match DomainReaderFactory::create(&app_config.domains) {
+        Ok(reader) => reader,
+        Err(e) => {
+            warn!("Error creating domain reader: {}", e);
+            return HttpResponse::InternalServerError().json(serde_json::json!({
+                "error": format!("Failed to create domain reader: {}", e)
+            }));
+        }
+    };
+
+    let domains = match domain_reader.read_domains().await {
+        Ok(domains) => domains,
+        Err(e) => {
+            warn!("Error reading domains: {}", e);
+            return HttpResponse::InternalServerError().json(serde_json::json!({
+                "error": format!("Failed to read domains: {}", e)
+            }));
+        }
+    };
+
+    // Check expiration for each domain and renew if needed
+    let mut results = Vec::new();
+    for domain_config in domains.iter() {
+        match get_cert_expiration_info(&app_config, &domain_config.domain, &base_path).await {
+            Ok(mut info) => {
+                // If certificate needs renewal, start renewal process in background
+                if info.needs_renewal {
+                    let app_config_clone = app_config.clone();
+                    let base_path_clone = base_path.clone();
+                    let domain_config_clone = domain_config.clone();
+
+                    // Spawn renewal task in background
+                    tokio::spawn(async move {
+                        if let Err(e) = renew_cert_if_needed(&app_config_clone, &domain_config_clone, &base_path_clone).await {
+                            warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e);
+                        }
+                    });
+
+                    info.renewing = true; // Mark as renewing
+                }
+                results.push(info);
+            }
+            Err(e) => {
+                warn!("Error checking certificate expiration for {}: {}", domain_config.domain, e);
+                // Add error info for this domain
+                results.push(CertificateExpirationInfo {
+                    domain: domain_config.domain.clone(),
+                    exists: false,
+                    created_at: None,
+                    expires_at: None,
+                    age_days: None,
+                    expires_in_days: None,
+                    needs_renewal: true,
+                    renewing: false,
+                });
+            }
+        }
+    }
+
+    HttpResponse::Ok().json(results)
+}
+
+/// Check DNS TXT record for DNS-01 challenge
+async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attempts: u32, delay_seconds: u64) -> bool {
+    use trust_dns_resolver::TokioAsyncResolver;
+
+    // Use Google's public DNS as the resolver (more predictable than system DNS);
+    // note this queries a public recursive resolver, not the authoritative servers directly
+    let resolver_config = ResolverConfig::google();
+
+    info!("Checking DNS TXT record: {} (expected value: {})", record_name, expected_value);
+    info!("DNS lookup settings: max_attempts={}, delay_seconds={}", max_attempts, delay_seconds);
+
+    for attempt in 1..=max_attempts {
+        // Create a new resolver for each attempt to ensure no caching
+        let mut resolver_opts = ResolverOpts::default();
+        resolver_opts.use_hosts_file = true;
+        resolver_opts.validate = false; // Don't validate DNSSEC to avoid issues
+        resolver_opts.attempts = 3; // Retry attempts per query
+        resolver_opts.timeout = std::time::Duration::from_secs(5); // 5 second timeout
+        resolver_opts.cache_size = 0; // Disable DNS cache by setting cache size to 0
+
+        // Create a fresh DNS resolver for each attempt to avoid any caching
+        let resolver = TokioAsyncResolver::tokio(
+            resolver_config.clone(),
+            resolver_opts,
+        );
+
+        match resolver.txt_lookup(record_name).await {
+            Ok(lookup) => {
+                let mut found_any = false;
+                let mut found_values = Vec::new();
+
+                // Check if any TXT record matches the expected value
+                for record in lookup.iter() {
+                    for txt_data in record.iter() {
+                        let txt_string = String::from_utf8_lossy(txt_data).trim().to_string();
+                        found_any = true;
+                        found_values.push(txt_string.clone());
+
+                        if txt_string == expected_value {
+                            info!("DNS TXT record matches expected value on attempt {}: {}", attempt, txt_string);
+                            return true;
+                        }
+                    }
+                }
+
+                if found_any {
+                    if attempt == 1 || attempt % 6 == 0 {
+                        warn!("DNS record found but value doesn't match. Expected: '{}', Found: {:?}", expected_value, found_values);
+                    }
+                } else {
+                    if attempt % 6 == 0 {
+                        info!("DNS record not found yet (attempt {}/{})...", attempt, max_attempts);
+                    }
+                }
+            }
+            Err(e) => {
+                if attempt == 1 || attempt % 6 == 0 {
+                    warn!("DNS lookup error on attempt {}: {}", attempt, e);
+                }
+            }
+        }
+
+        if attempt < max_attempts {
+            tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await;
+        }
+    }
+
+    warn!("DNS TXT record not found after {} attempts", max_attempts);
+    false
+}
+
+/// Check if ACME challenge endpoint is available
+/// This verifies that the ACME server is running and accessible before requesting certificates
+/// Retries with exponential backoff to handle cases where the server is still starting
+async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> {
+    use std::time::Duration;
+
+    // Build the ACME server URL (typically 127.0.0.1:9180)
+    let acme_url = format!("http://{}:{}/.well-known/acme-challenge/test-endpoint-check", config.opts.ip, config.opts.port);
+
+    debug!("Checking if ACME challenge endpoint is available at {}", acme_url);
+
+    // Create HTTP client with timeout
+    let client = reqwest::Client::builder()
+        .timeout(Duration::from_secs(5))
+        .build()
+        .context("Failed to create HTTP client for endpoint check")?;
+
+    // Retry logic: try up to 5 times with exponential backoff
+    // This handles cases where the ACME server is still starting up
+    let max_retries = 5;
+    let mut retry_delay = Duration::from_millis(10); // Start with 10ms
+
+    for attempt in 1..=max_retries {
+        match client.get(&acme_url).send().await {
+            Ok(response) => {
+                let status = response.status();
+                if status.is_success() || status.as_u16() == 404 {
+                    if attempt > 1 {
+                        debug!("ACME challenge endpoint is now available (status: {}) after {} attempt(s)", status, attempt);
+                    } else {
+                        debug!("ACME challenge endpoint is available (status: {})", status);
+                    }
+                    return Ok(());
+                } else {
+                    return Err(anyhow::anyhow!("ACME server returned unexpected status: {}", status));
+                }
+            }
+            Err(e) => {
+                let error_msg = e.to_string();
+                let is_connection_error = error_msg.contains("Connection refused")
+                    || error_msg.contains("connect")
+                    || error_msg.contains("connection")
+                    || error_msg.contains("refused")
+                    || e.is_connect()
+                    || e.is_timeout();
+
+                if attempt < max_retries && is_connection_error {
+                    debug!("ACME server not ready yet (attempt {}/{}), retrying in {:?}...", attempt, max_retries, retry_delay);
+                    tokio::time::sleep(retry_delay).await;
+                    // Exponential backoff: 10ms, 20ms, 40ms, 80ms between retries
+                    retry_delay = retry_delay * 2;
+                    continue;
+                }
+                // Last attempt or non-connection error - return error
+                if attempt >= max_retries {
+                    return Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts: {}", acme_url, max_retries, e));
+                } else {
+                    return Err(anyhow::anyhow!("Failed to connect to ACME server at {} (non-retryable error): {}", acme_url, e));
+                }
+            }
+        }
+    }
+
+    Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts", acme_url, max_retries))
+}
+
+/// Writes challenge file for HTTP-01 challenge
+/// The main HTTP server will serve this file - no temporary server needed
+async fn cert_init_server(
+    config: &Config,
+    challenge: &instant_acme::Challenge,
+    key_auth: &str,
+) -> AtomicServerResult<()> {
+    let storage = StorageFactory::create_default(config)?;
+    storage.write_challenge(&challenge.token.to_string(), key_auth).await?;
+
+    info!("Challenge file written. Main HTTP server will serve it at /.well-known/acme-challenge/{}", challenge.token);
+
+    Ok(())
+}
+
+/// Sends a request to LetsEncrypt to create a certificate
+pub async fn request_cert(config: &Config) -> AtomicServerResult<()> {
+    // Always use Redis storage (storage_type option is kept for compatibility but always uses Redis)
+    let storage_type = StorageType::Redis;
+
+    if storage_type == StorageType::Redis {
+        // Use distributed lock for Redis storage to prevent multiple instances from processing the same domain
+        // Create RedisStorage directly to access lock methods
+        let redis_storage = crate::acme::storage::RedisStorage::new(config)?;
+
+        // Lock TTL from config (default: 900 seconds = 15 minutes)
+        let lock_ttl_seconds = config.opts.lock_ttl_seconds.unwrap_or(900);
+
+        return redis_storage.with_lock(lock_ttl_seconds, || async {
+            request_cert_internal(config).await
+        }).await;
+    }
+
+    // Unreachable while Redis is the only storage type; kept for future backends
+    request_cert_internal(config).await
+}
+
+/// Parse retry-after timestamp from rate limit error message
+/// Returns the retry-after timestamp if found, None otherwise
+/// Handles both timestamp formats and ISO 8601 duration formats
+fn parse_retry_after(error_msg: &str) -> Option<chrono::DateTime<chrono::Utc>> {
+    use chrono::{Utc, Duration};
+
+    // Look for "retry after" pattern (case insensitive)
+    let error_msg_lower = error_msg.to_lowercase();
+    if let Some(pos) = error_msg_lower.find("retry after") {
+        // Get the text after "retry after" from the original message (preserve case for parsing)
+        let after_pos = error_msg[pos + "retry after".len()..].find(|c: char| !c.is_whitespace())
+            .unwrap_or(0);
+        let mut after_text_str = error_msg[pos + "retry after".len() + after_pos..].trim().to_string();
+
+        // Try to find the end of the timestamp/duration
+        // For timestamps like "2025-11-14 21:13:29 UTC", stop at end of line or before URL/links
+        // Look for common patterns that indicate end of timestamp:
+        // - End of string
+        // - Before URLs (http:// or https://)
+        // - Before "see" or ":" followed by URL
+        if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) {
+            after_text_str = after_text_str[..url_pos].trim().to_string();
+        }
+        // Extract timestamp - format is typically "2025-11-14 21:13:29 UTC" followed by ": see https://..."
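// Worked example (hypothetical input, not part of the change): given
//   "too many certificates ... retry after 2025-11-14 21:13:29 UTC: see https://letsencrypt.org/docs/rate-limits/"
// the URL strip above leaves "2025-11-14 21:13:29 UTC: see", and the " UTC"
// search below then truncates it to "2025-11-14 21:13:29 UTC", which the
// timestamp parsers further down accept.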
+        // Simplest approach: find " UTC" and take everything up to and including it
+        if let Some(utc_pos) = after_text_str.find(" UTC") {
+            // Found " UTC", extract up to and including it (this is the complete timestamp)
+            after_text_str = after_text_str[..utc_pos + 4].trim().to_string();
+        } else {
+            // No " UTC" found, try to stop before URLs or "see" keyword
+            if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) {
+                after_text_str = after_text_str[..url_pos].trim().to_string();
+            }
+            if let Some(see_pos) = after_text_str.find(" see ") {
+                after_text_str = after_text_str[..see_pos].trim().to_string();
+            }
+        }
+
+        let after_text = after_text_str.as_str();
+
+        // First, try to parse as timestamp (format: "2025-11-10 18:08:38 UTC")
+        // Try with timezone first
+        if let Ok(dt) = chrono::DateTime::parse_from_str(after_text, "%Y-%m-%d %H:%M:%S %Z") {
+            return Some(dt.with_timezone(&chrono::Utc));
+        }
+        // Try with "UTC" as separate word (common format: "2025-11-14 21:13:29 UTC")
+        if after_text.ends_with(" UTC") {
+            let without_tz = &after_text[..after_text.len() - 4].trim();
+            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(without_tz, "%Y-%m-%d %H:%M:%S") {
+                return Some(chrono::DateTime::from_naive_utc_and_offset(dt, chrono::Utc));
+            }
+        }
+        // Try alternative format without timezone (assume UTC)
+        if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(after_text, "%Y-%m-%d %H:%M:%S") {
+            return Some(chrono::DateTime::from_naive_utc_and_offset(dt, chrono::Utc));
+        }
+        // Try parsing as RFC3339 format
+        if let Ok(dt) = after_text.parse::<chrono::DateTime<chrono::Utc>>() {
+            return Some(dt);
+        }
+        // Try ISO 8601/RFC3339 format via the "%+" format specifier
+        // RFC3339 format: "2025-11-14T21:13:29Z" or "2025-11-14T21:13:29+00:00"
+        if let Ok(dt) = chrono::DateTime::parse_from_str(after_text, "%+") {
+            return Some(dt.with_timezone(&chrono::Utc));
+        }
+
+        // Try parsing as ISO 8601 duration (e.g., "PT86225.992004616S" or "PT24H")
+        // This happens when the error message contains a duration instead of a timestamp
+        if after_text.starts_with("PT") {
+            // Parse ISO 8601 duration: PT[nH][nM][nS] or PT[n]S
+            // Handle case where duration ends with 'S' (seconds)
+            let duration_str = if after_text.ends_with('S') && !after_text.ends_with("MS") && !after_text.ends_with("HS") {
+                &after_text[2..after_text.len()-1] // Remove "PT" prefix and "S" suffix
+            } else {
+                &after_text[2..] // Just remove "PT" prefix
+            };
+
+            // Try to parse as seconds (e.g., "86225.992004616")
+            if let Ok(seconds) = duration_str.parse::<f64>() {
+                let duration = Duration::seconds(seconds as i64) + Duration::nanoseconds((seconds.fract() * 1_000_000_000.0) as i64);
+                return Some(Utc::now() + duration);
+            }
+
+            // Try to parse hours, minutes, seconds separately
+            let mut total_seconds = 0.0;
+            let mut current_num = String::new();
+            let mut current_unit = String::new();
+
+            for ch in duration_str.chars() {
+                if ch.is_ascii_digit() || ch == '.' {
+                    if !current_unit.is_empty() {
+                        // Process previous unit
+                        if let Ok(val) = current_num.parse::<f64>() {
+                            match current_unit.as_str() {
+                                "H" => total_seconds += val * 3600.0,
+                                "M" => total_seconds += val * 60.0,
+                                "S" => total_seconds += val,
+                                _ => {}
+                            }
+                        }
+                        current_num.clear();
+                        current_unit.clear();
+                    }
+                    current_num.push(ch);
+                } else if ch.is_ascii_alphabetic() {
+                    current_unit.push(ch);
+                }
+            }
+
+            // Process last unit
+            if !current_unit.is_empty() && !current_num.is_empty() {
+                if let Ok(val) = current_num.parse::<f64>() {
+                    match current_unit.as_str() {
+                        "H" => total_seconds += val * 3600.0,
+                        "M" => total_seconds += val * 60.0,
+                        "S" => total_seconds += val,
+                        _ => {}
+                    }
+                }
+            }
+
+            if total_seconds > 0.0 {
+                let duration = Duration::seconds(total_seconds as i64) + Duration::nanoseconds((total_seconds.fract() * 1_000_000_000.0) as i64);
+                return Some(Utc::now() + duration);
+            }
+        }
+    }
+    None
+}
+
+/// Helper function to check if an account already exists
+async fn check_account_exists(
+    email: &str,
+    lets_encrypt_url: &str,
+) -> Result<Option<(instant_acme::Account, instant_acme::AccountCredentials)>, anyhow::Error> {
+    match instant_acme::Account::builder()
+        .context("Failed to create account builder")?
+        .create(
+            &instant_acme::NewAccount {
+                contact: &[&format!("mailto:{}", email)],
+                terms_of_service_agreed: true,
+                only_return_existing: true,
+            },
+            lets_encrypt_url.to_string(),
+            None,
+        )
+        .await
+    {
+        Ok((acc, cr)) => Ok(Some((acc, cr))),
+        Err(e) => {
+            let error_msg = format!("{}", e);
+            // If it's a rate limit error, propagate it
+            if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") {
+                return Err(e.into());
+            }
+            // Otherwise, account doesn't exist
+            Ok(None)
+        }
+    }
+}
+
+/// Helper function to create a new Let's Encrypt account and save credentials
+/// Handles rate limits by waiting for the retry-after time
+async fn create_new_account(
+    storage: &Box<dyn Storage>,
+    email: &str,
+    lets_encrypt_url: &str,
+) -> AtomicServerResult<(instant_acme::Account, instant_acme::AccountCredentials)> {
+    // First, check if account already exists
+    match check_account_exists(email, lets_encrypt_url).await {
+        Ok(Some((acc, cr))) => {
+            info!("Account already exists for email {}, reusing it", email);
+            return Ok((acc, cr));
+        }
+        Ok(None) => {
+            // Account doesn't exist, proceed to create
+        }
+        Err(e) => {
+            // Check if it's a rate limit error
+            let error_msg = format!("{}", e);
+            if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") {
+                if let Some(retry_after) = parse_retry_after(&error_msg) {
+                    let now = chrono::Utc::now();
+                    if retry_after > now {
+                        let wait_duration = retry_after - now;
+                        let wait_secs = wait_duration.num_seconds().max(0) as u64;
+                        warn!("Rate limit hit. Waiting {} seconds until {} before retrying account creation", wait_secs, retry_after);
+                        tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await;
+                    }
+                } else {
+                    // Rate limit error but couldn't parse retry-after, wait a default time
+                    warn!("Rate limit hit but couldn't parse retry-after time. 
Waiting 3 hours (10800 seconds) before retrying"); + tokio::time::sleep(tokio::time::Duration::from_secs(10800)).await; + } + } else { + // Not a rate limit error, propagate it + return Err(e); + } + } + } + + info!("Creating new LetsEncrypt account with email {}", email); + + // Retry account creation (after waiting for rate limit if needed) + let max_retries = 3; + let mut retry_count = 0; + + loop { + match instant_acme::Account::builder() + .context("Failed to create account builder")? + .create( + &instant_acme::NewAccount { + contact: &[&format!("mailto:{}", email)], + terms_of_service_agreed: true, + only_return_existing: false, + }, + lets_encrypt_url.to_string(), + None, + ) + .await + { + Ok((account, creds)) => { + // Save credentials for future use (store as JSON value for now) + if let Ok(creds_json) = serde_json::to_string(&creds) { + if let Err(e) = storage.write_account_credentials(&creds_json).await { + warn!("Failed to save account credentials to storage: {}. Account will be recreated on next run.", e); + } else { + info!("Saved LetsEncrypt account credentials to storage"); + } + } else { + warn!("Failed to serialize account credentials. Account will be recreated on next run."); + } + return Ok((account, creds)); + } + Err(e) => { + let error_msg = format!("{}", e); + + // Check if it's a rate limit error + if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { + if let Some(retry_after) = parse_retry_after(&error_msg) { + let now = chrono::Utc::now(); + if retry_after > now { + let wait_duration = retry_after - now; + let wait_secs = wait_duration.num_seconds().max(0) as u64; + warn!("Rate limit hit during account creation. Waiting {} seconds until {} before retrying", wait_secs, retry_after); + tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await; + retry_count += 1; + if retry_count < max_retries { + continue; + } + } + } else { + // Rate limit error but couldn't parse retry-after + if retry_count < max_retries { + let wait_secs = 10800; // 3 hours default + warn!("Rate limit hit but couldn't parse retry-after time. Waiting {} seconds before retrying", wait_secs); + tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs)).await; + retry_count += 1; + continue; + } + } + } + + // If we've exhausted retries or it's not a rate limit error, return the error + return Err(e).context("Failed to create account"); + } + } + } +} + +async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { + use instant_acme::OrderStatus; + + // Detect wildcard domain and automatically use DNS-01 + let is_wildcard = config.opts.domain.starts_with("*."); + let use_dns = config.opts.https_dns || is_wildcard; + + if is_wildcard && !config.opts.https_dns { + warn!("Wildcard domain detected ({}), automatically using DNS-01 challenge", config.opts.domain); + } + + let challenge_type = if use_dns { + debug!("Using DNS-01 challenge"); + instant_acme::ChallengeType::Dns01 + } else { + debug!("Using HTTP-01 challenge"); + // Check if ACME challenge endpoint is available before proceeding + if let Err(e) = check_acme_challenge_endpoint(config).await { + let error_msg = format!("ACME challenge endpoint not available for HTTP-01 challenge: {}. 
Skipping certificate request.", e);
+            warn!("{}", error_msg);
+            let storage = StorageFactory::create_default(config)?;
+            if let Err(record_err) = storage.record_failure(&error_msg).await {
+                warn!("Failed to record failure: {}", record_err);
+            }
+            return Ok(());
+        }
+        instant_acme::ChallengeType::Http01
+    };
+
+    // Create a new account. This will generate a fresh ECDSA key for you.
+    // Alternatively, restore an account from serialized credentials by
+    // using `Account::from_credentials()`.
+
+    let lets_encrypt_url = if config.opts.development {
+        warn!(
+            "Using LetsEncrypt staging server, not production. This is for testing purposes only and will not provide a working certificate."
+        );
+        instant_acme::LetsEncrypt::Staging.url()
+    } else {
+        instant_acme::LetsEncrypt::Production.url()
+    };
+
+    let email =
+        config.opts.email.clone().expect(
+            "No email set - required for HTTPS certificate initialization with LetsEncrypt",
+        );
+
+    // Try to load existing account credentials from storage
+    let storage = StorageFactory::create_default(config)?;
+    let existing_creds = storage.read_account_credentials().await
+        .context("Failed to read account credentials from storage")?;
+
+    // Try to restore account from stored credentials, but fall back to creating new account if it fails
+    let (account, _creds) = match existing_creds {
+        Some(creds_json) => {
+            // Try to restore account from existing credentials
+            debug!("Attempting to restore LetsEncrypt account from stored credentials");
+
+            // First try to parse and restore from stored credentials
+            match serde_json::from_str::<instant_acme::AccountCredentials>(&creds_json) {
+                Ok(creds) => {
+                    // Try to restore account from credentials
+                    // Use AccountBuilder to restore from credentials
+                    match instant_acme::Account::builder()
+                        .context("Failed to create account builder")?
+                        .from_credentials(creds)
+                        .await
+                    {
+                        Ok(acc) => {
+                            debug!("Successfully restored LetsEncrypt account from stored credentials");
+                            // Get the credentials back from the account (they're stored in the account)
+                            // For now, we'll use the stored credentials JSON
+                            let restored_creds = serde_json::from_str::<instant_acme::AccountCredentials>(&creds_json)
+                                .expect("Credentials were just parsed successfully");
+                            (acc, restored_creds)
+                        }
+                        Err(e) => {
+                            let error_msg = format!("{}", e);
+                            warn!("Failed to restore account from stored credentials: {}. Will check if account exists.", error_msg);
+
+                            // If restoration fails, check if account exists
+                            match check_account_exists(&email, lets_encrypt_url).await {
+                                Ok(Some((acc, cr))) => {
+                                    info!("Account exists but credentials were invalid. Using existing account.");
+                                    (acc, cr)
+                                }
+                                Ok(None) => {
+                                    warn!("Stored credentials invalid and account doesn't exist. Creating new account.");
+                                    create_new_account(&storage, &email, lets_encrypt_url).await?
+                                }
+                                Err(e) => {
+                                    let error_msg = format!("{}", e);
+                                    if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") {
+                                        warn!("Rate limit hit while checking account. Will wait and retry in create_new_account.");
+                                        create_new_account(&storage, &email, lets_encrypt_url).await?
+                                    } else {
+                                        warn!("Failed to check account existence: {}. Creating new account.", e);
+                                        create_new_account(&storage, &email, lets_encrypt_url).await?
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    warn!("Failed to parse stored credentials: {}. Creating new account.", e);
+                    create_new_account(&storage, &email, lets_encrypt_url).await?
+ } + } + } + None => { + // No stored credentials, create a new account + create_new_account(&storage, &email, lets_encrypt_url).await? + } + }; + + // Create the ACME order based on the given domain names. + // Note that this only needs an `&Account`, so the library will let you + // process multiple orders in parallel for a single account. + + // Prepare domain for ACME order + // For wildcard domains (*.example.com), we need to request *.example.com + // For non-wildcard domains with DNS-01, we request the domain as-is + let domain = config.opts.domain.clone(); + let is_wildcard_domain = domain.starts_with("*."); + + if is_wildcard_domain { + // Domain already has wildcard prefix, use as-is for ACME order + // ACME requires *.example.com format for wildcard certificates + debug!("Requesting wildcard certificate for: {}", domain); + } else if use_dns { + // Non-wildcard domain with DNS-01 challenge - use domain as-is + debug!("Requesting certificate for domain with DNS-01: {}", domain); + } else { + // HTTP-01 challenge - use domain as-is + debug!("Requesting certificate for domain with HTTP-01: {}", domain); + } + + // Check if we're still in rate limit period before attempting request + use chrono::{Utc, Duration}; + let storage = StorageFactory::create_default(config)?; + if let Ok(Some((last_failure_time, last_failure_msg))) = storage.get_last_failure().await { + // Check if the last failure was a rate limit error + if last_failure_msg.contains("rateLimited") || + last_failure_msg.contains("rate limit") || + last_failure_msg.contains("too many certificates") { + // Parse retry-after time from error message + if let Some(retry_after) = parse_retry_after(&last_failure_msg) { + let now = Utc::now(); + if now < retry_after { + let wait_duration = retry_after - now; + info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", + config.opts.domain, retry_after, wait_duration); + return Ok(()); + } else { + debug!("Rate limit period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); + } + } else { + // Log the error message for debugging + tracing::debug!("Failed to parse retry-after from error message: {}", last_failure_msg); + // Can't parse retry-after from error message + // Try to extract duration from the error message if it contains ISO 8601 duration + // Look for patterns like "PT86225S" or "PT24H" anywhere in the message + let mut found_duration = None; + for word in last_failure_msg.split_whitespace() { + if word.starts_with("PT") { + // Try to parse as ISO 8601 duration + if let Some(dt) = parse_retry_after(&format!("retry after {}", word)) { + found_duration = Some(dt); + break; + } + } + } + + if let Some(retry_after) = found_duration { + let now = Utc::now(); + if now < retry_after { + let wait_duration = retry_after - now; + info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", + config.opts.domain, retry_after, wait_duration); + return Ok(()); + } + } else { + // Can't parse retry-after at all, use exponential backoff (24 hours minimum for rate limits) + let rate_limit_cooldown = Duration::hours(24); + let now = Utc::now(); + if now - last_failure_time < rate_limit_cooldown { + let remaining = rate_limit_cooldown - (now - last_failure_time); + warn!("Rate limit error detected for domain {} (retry-after time not parseable from: '{}'). Waiting {} before retry. 
Skipping certificate request.", + config.opts.domain, last_failure_msg, remaining); + return Ok(()); + } else { + debug!("Rate limit cooldown period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); + } + } + } + } + } + + let identifier = instant_acme::Identifier::Dns(domain.clone()); + let identifiers = vec![identifier]; + let mut order = match account + .new_order(&instant_acme::NewOrder::new(&identifiers)) + .await + { + Ok(order) => order, + Err(e) => { + let error_msg = format!("Failed to create new order for domain {}: {}", config.opts.domain, e); + warn!("{}. Skipping certificate request.", error_msg); + + // If it's a rate limit error, store it with retry-after time + let is_rate_limit = error_msg.contains("rateLimited") || + error_msg.contains("rate limit") || + error_msg.contains("too many certificates"); + + if is_rate_limit { + if let Some(retry_after) = parse_retry_after(&error_msg) { + info!("Rate limit error for domain {}: will retry after {}", config.opts.domain, retry_after); + } else { + warn!("Rate limit error for domain {} but could not parse retry-after time. Will wait 24 hours before retry.", config.opts.domain); + } + } + + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + }; + + // Check if order is already ready (from a previous request) + let initial_state = order.state(); + + // Handle unexpected order status + if !matches!(initial_state.status, instant_acme::OrderStatus::Pending | instant_acme::OrderStatus::Ready) { + let error_msg = format!("Unexpected order status: {:?} for domain {}", initial_state.status, config.opts.domain); + warn!("{}. Skipping certificate request.", error_msg); + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + + // If order is already Ready, skip challenge processing + let state = if matches!(initial_state.status, instant_acme::OrderStatus::Ready) { + info!("Order is already in Ready state, skipping challenge processing and proceeding to finalization"); + // Use initial_state as the final state since we're skipping challenge processing + initial_state + } else { + // Order is Pending, proceed with challenge processing + // Pick the desired challenge type and prepare the response. + let mut authorizations = order.authorizations(); + let mut challenges_set = Vec::new(); + + while let Some(result) = authorizations.next().await { + let mut authz = match result { + Ok(authz) => authz, + Err(e) => { + warn!("Failed to get authorization: {}. 
Skipping this authorization.", e); + continue; + } + }; + let domain = authz.identifier().to_string(); + + match authz.status { + instant_acme::AuthorizationStatus::Pending => {} + instant_acme::AuthorizationStatus::Valid => continue, + _ => todo!(), + } + + let mut challenge = match authz.challenge(challenge_type.clone()) { + Some(c) => c, + None => { + warn!("Domain '{}': No {:?} challenge found, skipping", domain, challenge_type); + continue; + } + }; + + let key_auth = challenge.key_authorization().as_str().to_string(); + match challenge_type { + instant_acme::ChallengeType::Http01 => { + // Check if existing challenge is expired and clean it up + let storage = StorageFactory::create_default(config)?; + let challenge_token = challenge.token.to_string(); + if let Ok(Some(_)) = storage.get_challenge_timestamp(&challenge_token).await { + // Challenge exists, check if expired + let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); + if let Ok(true) = storage.is_challenge_expired(&challenge_token, max_ttl).await { + info!("Existing challenge for token {} is expired (TTL: {}s), will be replaced", challenge_token, max_ttl); + } + } + + if let Err(e) = cert_init_server(config, &challenge, &key_auth).await { + warn!("Failed to write challenge file for HTTP-01 challenge: {}. Skipping HTTP-01 challenge.", e); + continue; + } + } + instant_acme::ChallengeType::Dns01 => { + // For DNS-01 challenge, the TXT record should be at _acme-challenge.{base_domain} + // For wildcard domains (*.example.com), use the base domain (example.com) + // For non-wildcard domains, use the domain as-is + // Use the is_wildcard flag computed earlier, or check the domain from authorization + let base_domain = if domain.starts_with("*.") { + // Domain from authorization starts with *. - strip it + domain.strip_prefix("*.").unwrap_or(&domain) + } else if is_wildcard { + // is_wildcard is true but domain doesn't start with *. 
+ // This can happen if ACME returns the base domain instead of wildcard + // Use the domain as-is (it's already the base domain) + &domain + } else { + // For non-wildcard, use domain as-is + &domain + }; + let dns_record = format!("_acme-challenge.{}", base_domain); + let dns_value = challenge.key_authorization().dns_value(); + + info!("DNS-01 challenge for domain '{}' (base domain: {}, wildcard: {}):", domain, base_domain, is_wildcard); + info!(" Create DNS TXT record: {} IN TXT {}", dns_record, dns_value); + info!(" This record must be added to your DNS provider before the challenge can be validated."); + + // Check if existing DNS challenge is expired and clean it up + let storage = StorageFactory::create_default(config)?; + if let Ok(Some(_)) = storage.get_dns_challenge_timestamp(&domain).await { + // DNS challenge exists, check if expired + let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); + if let Ok(true) = storage.is_dns_challenge_expired(&domain, max_ttl).await { + info!("Existing DNS challenge for domain {} is expired (TTL: {}s), will be replaced", domain, max_ttl); + } + } + + // Save DNS challenge code to storage (Redis or file) + if let Err(e) = storage.write_dns_challenge(&domain, &dns_record, &dns_value).await { + warn!("Failed to save DNS challenge code to storage: {}", e); + } + + info!("Waiting for DNS record to propagate..."); + + // Automatically check DNS records + let max_attempts = config.opts.dns_lookup_max_attempts.unwrap_or(100); + let delay_seconds = config.opts.dns_lookup_delay_seconds.unwrap_or(10); + let dns_ready = check_dns_txt_record(&dns_record, &dns_value, max_attempts, delay_seconds).await; + + if !dns_ready { + let error_msg = format!("DNS record not found after checking for domain {}. Record: {} IN TXT {}", domain, dns_record, dns_value); + warn!("{}. Please verify the DNS record is set correctly.", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + + info!("DNS record found! Proceeding with challenge validation..."); + } + instant_acme::ChallengeType::TlsAlpn01 => todo!("TLS-ALPN-01 is not supported"), + _ => { + let error_msg = format!("Unsupported challenge type: {:?}", challenge_type); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + } + + // Notify ACME server to validate + info!("Domain '{}': Notifying ACME server to validate challenge", domain); + challenge.set_ready().await + .with_context(|| format!("Failed to set challenge ready for domain {}", domain))?; + challenges_set.push(domain); + } + + if challenges_set.is_empty() { + let error_msg = format!("All domains failed challenge setup for domain {}", config.opts.domain); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + + // Exponentially back off until the order becomes ready or invalid. 
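// Note on the schedule below (derived from the code, not part of the change):
// the delay actually grows linearly (2 + tries seconds), i.e. 3 s, 4 s, ..., 11 s
// across the ten permitted refresh attempts -- roughly a minute of total waiting
// before the order is given up as not ready.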
+ let mut tries = 0u8; + let state = loop { + let state = match order.refresh().await { + Ok(s) => s, + Err(e) => { + if tries >= 10 { + let error_msg = format!("Order refresh failed after {} attempts: {}", tries, e); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + tries += 1; + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + continue; + } + }; + + info!("Order state: {:#?}", state); + if let OrderStatus::Ready | OrderStatus::Invalid | OrderStatus::Valid = state.status { + break state; + } + + tries += 1; + if tries >= 10 { + let error_msg = format!("Giving up: order is not ready after {} attempts for domain {}", tries, config.opts.domain); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + + let delay = std::time::Duration::from_secs(2 + tries as u64); + info!("order is not ready, waiting {delay:?}"); + tokio::time::sleep(delay).await; + }; + + if state.status == OrderStatus::Invalid { + // Try to get more details about why the order is invalid + let mut error_details = Vec::new(); + if let Some(error) = &state.error { + error_details.push(format!("Order error: {:?}", error)); + } + + // Fetch authorization details from ACME server if state is None + for auth in &state.authorizations { + if let Some(auth_state) = &auth.state { + // Check authorization status for more details + match &auth_state.status { + instant_acme::AuthorizationStatus::Invalid => { + error_details.push(format!("Authorization {} is invalid", auth.url)); + } + instant_acme::AuthorizationStatus::Expired => { + error_details.push(format!("Authorization {} expired", auth.url)); + } + instant_acme::AuthorizationStatus::Revoked => { + error_details.push(format!("Authorization {} revoked", auth.url)); + } + _ => {} + } + } else { + // Authorization state is None - this means the authorization details weren't included in the order state + // We can't fetch it again because order.authorizations() was already consumed + // Log the URL so the user can check it manually + warn!("Authorization state is None for {}. This usually means the authorization failed or expired. Check the authorization URL for details.", auth.url); + error_details.push(format!("Authorization {} state unavailable (check URL for details)", auth.url)); + } + } + + let error_msg = if error_details.is_empty() { + format!("Order is invalid but no error details available. Order state: {:#?}", state) + } else { + format!("Order is invalid. 
Details: {}", error_details.join("; "))
+        };
+        warn!("{}", error_msg);
+        let storage = StorageFactory::create_default(config)?;
+        if let Err(record_err) = storage.record_failure(&error_msg).await {
+            warn!("Failed to record failure: {}", record_err);
+        }
+        return Ok(());
+        }
+
+        state
+    };
+
+    // Check if state is invalid before proceeding to finalization
+    if state.status == OrderStatus::Invalid {
+        let error_msg = format!("Order is invalid for domain {}", config.opts.domain);
+        warn!("{}", error_msg);
+        let storage = StorageFactory::create_default(config)?;
+        if let Err(record_err) = storage.record_failure(&error_msg).await {
+            warn!("Failed to record failure: {}", record_err);
+        }
+        return Ok(());
+    }
+
+    // If the order is ready, we can provision the certificate.
+    // Finalize the order - this will generate a CSR and return the private key PEM.
+    let private_key_pem = order.finalize().await
+        .context("Failed to finalize ACME order")?;
+
+    // Give the CA a moment before polling; use a non-blocking sleep so the
+    // async executor is not stalled (std::thread::sleep would block it)
+    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+    let mut tries = 1u8;
+
+    let cert_chain_pem = loop {
+        match order.certificate().await {
+            Ok(Some(cert_chain_pem)) => {
+                info!("Certificate ready!");
+                break cert_chain_pem;
+            }
+            Ok(None) => {
+                if tries > 10 {
+                    let error_msg = format!("Giving up: certificate is still not ready after {} attempts", tries);
+                    let storage = StorageFactory::create_default(config)?;
+                    if let Err(record_err) = storage.record_failure(&error_msg).await {
+                        warn!("Failed to record failure: {}", record_err);
+                    }
+                    return Err(anyhow!("{}", error_msg));
+                }
+                tries += 1;
+                info!("Certificate not ready yet...");
+                // Wait between polls instead of busy-looping against the CA
+                tokio::time::sleep(std::time::Duration::from_secs(2)).await;
+                continue;
+            }
+            Err(e) => {
+                let error_msg = format!("Error getting certificate: {}", e);
+                let storage = StorageFactory::create_default(config)?;
+                if let Err(record_err) = storage.record_failure(&error_msg).await {
+                    warn!("Failed to record failure: {}", record_err);
+                }
+                return Err(anyhow!("{}", error_msg));
+            }
+        }
+    };
+
+    write_certs(config, cert_chain_pem, private_key_pem).await
+        .context("Failed to write certificates to storage")?;
+
+    // Clear any previous failure records since certificate was successfully generated
+    let storage = StorageFactory::create_default(config)?;
+    if let Err(clear_err) = storage.clear_failure().await {
+        warn!("Failed to clear failure record: {}", clear_err);
+    }
+
+    info!("HTTPS TLS Cert init successful! 
Certificate written to storage."); + + Ok(()) +} + +async fn write_certs( + config: &Config, + cert_chain_pem: String, + private_key_pem: String, +) -> AtomicServerResult<()> { + // Always use Redis storage (storage_type option is kept for compatibility but always uses Redis) + info!("Creating Redis storage backend"); + let storage = StorageFactory::create_default(config)?; + info!("Storage backend created successfully"); + + info!("Writing TLS certificates to storage (certbot-style)"); + + // Parse the certificate chain to separate cert from chain + // The cert_chain_pem contains the domain cert first, followed by intermediate certs + // It's already in PEM format, so we split it by "-----BEGIN CERTIFICATE-----" + let cert_parts: Vec = cert_chain_pem + .split("-----BEGIN CERTIFICATE-----") + .filter(|s| !s.trim().is_empty()) + .map(|s| format!("-----BEGIN CERTIFICATE-----{}", s)) + .collect(); + + if cert_parts.is_empty() { + return Err(anyhow!("No certificates found in chain")); + } + + // First certificate is the domain certificate + let domain_cert_pem = cert_parts[0].trim().to_string(); + + // Remaining certificates form the chain + let chain_pem = if cert_parts.len() > 1 { + cert_parts[1..].join("\n") + } else { + String::new() + }; + + // Combine cert and chain to create fullchain + let mut fullchain = domain_cert_pem.clone(); + if !chain_pem.is_empty() { + fullchain.push_str("\n"); + fullchain.push_str(&chain_pem); + } + + info!("Writing certificate to Redis storage backend..."); + storage.write_certs( + domain_cert_pem.as_bytes(), + chain_pem.as_bytes(), + private_key_pem.as_bytes(), + ).await + .context("Failed to write certificates to storage backend")?; + info!("Certificates written successfully to Redis storage backend"); + + storage.write_created_at(chrono::Utc::now()).await + .context("Failed to write created_at timestamp")?; + + // Save certificates to proxy_certificates path + if let Some(proxy_certificates_path) = get_proxy_certificates_path() { + if let Err(e) = save_cert_to_proxy_path( + &config.opts.domain, + &fullchain, + &private_key_pem, + &proxy_certificates_path, + ).await { + warn!("Failed to save certificate to proxy_certificates path: {}", e); + } else { + info!("Certificate saved to proxy_certificates path: {}", proxy_certificates_path); + } + } else { + warn!("proxy_certificates path not configured, skipping file save"); + } + info!("Created_at timestamp written successfully"); + + Ok(()) +} + +/// Start HTTP server for ACME challenge requests +/// This server only serves ACME challenge files and keeps running indefinitely +pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> { + let address = format!("{}:{}", app_config.server.ip, app_config.server.port); + info!("Starting HTTP server for ACME challenges at {}", address); + info!("Server will only accept ACME challenge requests at /.well-known/acme-challenge/*"); + info!("Certificate expiration check endpoints:"); + info!(" - GET /cert/expiration - Check all domains"); + info!(" - GET /cert/expiration/{{domain}} - Check specific domain"); + info!("To stop the program, press Ctrl+C"); + + // Use the base storage path for serving ACME challenges + // Challenges are stored in a shared location: https_path/well-known/acme-challenge/ + let base_static_path = std::path::PathBuf::from(&app_config.storage.https_path); + + // Build the path to the well-known/acme-challenge directory + // Files are stored at: base_path/well-known/acme-challenge/{token} + let mut challenge_static_path = 
base_static_path.clone(); + challenge_static_path.push("well-known"); + challenge_static_path.push("acme-challenge"); + + // Ensure the challenge directory exists (required for actix_files::Files) + // Even when using Redis storage, challenge files are still written to filesystem for HTTP-01 + tokio::fs::create_dir_all(&challenge_static_path) + .await + .with_context(|| format!("Failed to create challenge static path directory: {:?}", challenge_static_path))?; + + let base_https_path = base_static_path.clone(); + let app_config_data = web::Data::new(app_config.clone()); + let base_path_data = web::Data::new(base_https_path); + + // Create HTTP server that only serves ACME challenge files + // The server will serve from any domain's challenge directory + let server = HttpServer::new(move || { + App::new() + .app_data(app_config_data.clone()) + .app_data(base_path_data.clone()) + .service( + // Serve ACME challenges from the challenge directory + // URL: /.well-known/acme-challenge/{token} + // File: base_path/well-known/acme-challenge/{token} + // The Files service maps the URL path to the file system path + actix_files::Files::new("/.well-known/acme-challenge", challenge_static_path.clone()) + .prefer_utf8(true), + ) + .route( + "/cert/expiration", + web::get().to(check_all_certs_expiration_handler), + ) + .route( + "/cert/expiration/{domain}", + web::get().to(check_cert_expiration_handler), + ) + // Reject all other requests with 404 + .default_service(web::route().to(|| async { + HttpResponse::NotFound().body("Not Found") + })) + }) + .bind(&address) + .with_context(|| format!("Failed to bind HTTP server to {}", address))?; + + info!("HTTP server started successfully at {}", address); + + // Keep the server running indefinitely + server.run().await + .with_context(|| "HTTP server error")?; + + Ok(()) +} diff --git a/src/acme/mod.rs b/src/acme/mod.rs index 3ae4041..ef20ff4 100644 --- a/src/acme/mod.rs +++ b/src/acme/mod.rs @@ -1,19 +1,19 @@ -//! ACME certificate management module -//! Re-exports from lib.rs for use in main application - -mod errors; -pub mod config; -mod storage; -mod domain_reader; -pub mod upstreams_reader; -pub mod embedded; -mod lib; - -pub use errors::AtomicServerResult; -pub use config::{Config, ConfigOpts, AppConfig, RetryConfig, RedisSslConfig}; -pub use storage::{Storage, StorageFactory, StorageType}; -pub use domain_reader::{DomainConfig, DomainReader, DomainReaderFactory}; -pub use upstreams_reader::{UpstreamsDomainReader, UpstreamsAcmeConfig}; -pub use embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; -pub use lib::*; - +//! ACME certificate management module +//! Re-exports from lib.rs for use in main application + +mod errors; +pub mod config; +mod storage; +mod domain_reader; +pub mod upstreams_reader; +pub mod embedded; +mod lib; + +pub use errors::AtomicServerResult; +pub use config::{Config, ConfigOpts, AppConfig, RetryConfig, RedisSslConfig}; +pub use storage::{Storage, StorageFactory, StorageType}; +pub use domain_reader::{DomainConfig, DomainReader, DomainReaderFactory}; +pub use upstreams_reader::{UpstreamsDomainReader, UpstreamsAcmeConfig}; +pub use embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; +pub use lib::*; + diff --git a/src/acme/storage/mod.rs b/src/acme/storage/mod.rs index b3f92a1..b3abd04 100644 --- a/src/acme/storage/mod.rs +++ b/src/acme/storage/mod.rs @@ -1,136 +1,136 @@ -//! Storage backend abstraction for certificate storage. -//! Uses Redis as the only storage backend. 
-
-mod redis;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use std::path::PathBuf;
-
-pub use redis::RedisStorage;
-
-/// Storage backend type (Redis only)
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum StorageType {
-    /// Redis storage
-    Redis,
-}
-
-/// Trait for certificate storage backends
-#[async_trait]
-pub trait Storage: Send + Sync {
-    /// Read certificate from storage (from live directory)
-    async fn read_cert(&self) -> Result<Vec<u8>>;
-
-    /// Read certificate chain from storage (from live directory)
-    async fn read_chain(&self) -> Result<Vec<u8>>;
-
-    /// Read fullchain (cert + chain) from storage (from live directory)
-    async fn read_fullchain(&self) -> Result<Vec<u8>>;
-
-    /// Read private key from storage (from live directory)
-    async fn read_key(&self) -> Result<Vec<u8>>;
-
-    /// Write certificate, chain, and fullchain to storage (certbot-style)
-    /// cert: The domain certificate
-    /// chain: The intermediate certificate chain
-    /// key: The private key
-    async fn write_certs(&self, cert: &[u8], chain: &[u8], key: &[u8]) -> Result<()>;
-
-    /// Get the SHA256 hash of the certificate (fullchain + key combined)
-    /// Returns None if certificate doesn't exist
-    async fn get_certificate_hash(&self) -> Result<Option<String>>;
-
-    /// Check if certificate exists
-    async fn cert_exists(&self) -> bool;
-
-    /// Read certificate creation timestamp
-    async fn read_created_at(&self) -> Result<chrono::DateTime<chrono::Utc>>;
-
-    /// Write certificate creation timestamp
-    async fn write_created_at(&self, created_at: chrono::DateTime<chrono::Utc>) -> Result<()>;
-
-    /// Write challenge file for ACME HTTP-01 challenge
-    async fn write_challenge(&self, token: &str, key_auth: &str) -> Result<()>;
-
-    /// Write DNS challenge code for ACME DNS-01 challenge
-    async fn write_dns_challenge(&self, domain: &str, dns_record: &str, dns_value: &str) -> Result<()>;
-
-    /// Get the timestamp when a challenge was created
-    /// Returns None if challenge doesn't exist or has no timestamp
-    async fn get_challenge_timestamp(&self, token: &str) -> Result<Option<chrono::DateTime<chrono::Utc>>>;
-
-    /// Get the timestamp when a DNS challenge was created
-    /// Returns None if challenge doesn't exist or has no timestamp
-    async fn get_dns_challenge_timestamp(&self, domain: &str) -> Result<Option<chrono::DateTime<chrono::Utc>>>;
-
-    /// Check if a challenge is expired based on TTL
-    async fn is_challenge_expired(&self, token: &str, max_ttl_seconds: u64) -> Result<bool>;
-
-    /// Check if a DNS challenge is expired based on TTL
-    async fn is_dns_challenge_expired(&self, domain: &str, max_ttl_seconds: u64) -> Result<bool>;
-
-    /// Clean up expired challenges
-    async fn cleanup_expired_challenges(&self, max_ttl_seconds: u64) -> Result<()>;
-
-    /// Get the path for static files (well-known directory)
-    fn static_path(&self) -> PathBuf;
-
-    /// Read fullchain synchronously (for compatibility with sync APIs like rustls)
-    /// Returns None if the storage backend doesn't support sync operations
-    fn read_fullchain_sync(&self) -> Option<Result<Vec<u8>>> {
-        None
-    }
-
-    /// Read private key synchronously (for compatibility with sync APIs like rustls)
-    /// Returns None if the storage backend doesn't support sync operations
-    fn read_key_sync(&self) -> Option<Result<Vec<u8>>> {
-        None
-    }
-
-    /// Read ACME account credentials from storage
-    /// Returns None if credentials don't exist
-    async fn read_account_credentials(&self) -> Result<Option<String>>;
-
-    /// Write ACME account credentials to storage
-    async fn write_account_credentials(&self, credentials: &str) -> Result<()>;
-
-    /// Record a certificate generation failure
-    async fn record_failure(&self, error: &str) -> Result<()>;
-
-    /// Get the last failure timestamp and error message
-    /// Returns None if no failure was recorded
-    async fn get_last_failure(&self) -> Result<Option<(chrono::DateTime<chrono::Utc>, String)>>;
-
-    /// Clear failure record (called when certificate is successfully generated)
-    async fn clear_failure(&self) -> Result<()>;
-
-    /// Get the number of consecutive failures
-    async fn get_failure_count(&self) -> Result<u32>;
-}
-
-/// Factory for creating storage backends
-pub struct StorageFactory;
-
-impl StorageFactory {
-    /// Create a storage backend (Redis only)
-    pub fn create(storage_type: StorageType, config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
-        match storage_type {
-            StorageType::Redis => {
-                Ok(Box::new(RedisStorage::new(config)?))
-            }
-        }
-    }
-
-    /// Create storage from AppConfig storage settings (Redis only)
-    pub fn create_from_app_config(_app_config: &crate::acme::AppConfig, domain_config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
-        Self::create(StorageType::Redis, domain_config)
-    }
-
-    /// Create storage based on config settings (Redis only)
-    pub fn create_default(config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
-        tracing::debug!("Creating Redis storage backend for domain: {}", config.opts.domain);
-        Self::create(StorageType::Redis, config)
-    }
-}
-
+//! Storage backend abstraction for certificate storage.
+//! Uses Redis as the only storage backend.
+
+mod redis;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use std::path::PathBuf;
+
+pub use redis::RedisStorage;
+
+/// Storage backend type (Redis only)
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StorageType {
+    /// Redis storage
+    Redis,
+}
+
+/// Trait for certificate storage backends
+#[async_trait]
+pub trait Storage: Send + Sync {
+    /// Read certificate from storage (from live directory)
+    async fn read_cert(&self) -> Result<Vec<u8>>;
+
+    /// Read certificate chain from storage (from live directory)
+    async fn read_chain(&self) -> Result<Vec<u8>>;
+
+    /// Read fullchain (cert + chain) from storage (from live directory)
+    async fn read_fullchain(&self) -> Result<Vec<u8>>;
+
+    /// Read private key from storage (from live directory)
+    async fn read_key(&self) -> Result<Vec<u8>>;
+
+    /// Write certificate, chain, and fullchain to storage (certbot-style)
+    /// cert: The domain certificate
+    /// chain: The intermediate certificate chain
+    /// key: The private key
+    async fn write_certs(&self, cert: &[u8], chain: &[u8], key: &[u8]) -> Result<()>;
+
+    /// Get the SHA256 hash of the certificate (fullchain + key combined)
+    /// Returns None if certificate doesn't exist
+    async fn get_certificate_hash(&self) -> Result<Option<String>>;
+
+    /// Check if certificate exists
+    async fn cert_exists(&self) -> bool;
+
+    /// Read certificate creation timestamp
+    async fn read_created_at(&self) -> Result<chrono::DateTime<chrono::Utc>>;
+
+    /// Write certificate creation timestamp
+    async fn write_created_at(&self, created_at: chrono::DateTime<chrono::Utc>) -> Result<()>;
+
+    /// Write challenge file for ACME HTTP-01 challenge
+    async fn write_challenge(&self, token: &str, key_auth: &str) -> Result<()>;
+
+    /// Write DNS challenge code for ACME DNS-01 challenge
+    async fn write_dns_challenge(&self, domain: &str, dns_record: &str, dns_value: &str) -> Result<()>;
+
+    /// Get the timestamp when a challenge was created
+    /// Returns None if challenge doesn't exist or has no timestamp
+    async fn get_challenge_timestamp(&self, token: &str) -> Result<Option<chrono::DateTime<chrono::Utc>>>;
+
+    /// Get the timestamp when a DNS challenge was created
+    /// Returns None if challenge doesn't exist or has no timestamp
+    async fn get_dns_challenge_timestamp(&self, domain: &str) -> Result<Option<chrono::DateTime<chrono::Utc>>>;
+
+    /// Check if a challenge is expired based on TTL
+    async fn is_challenge_expired(&self, token: &str, max_ttl_seconds: u64) -> Result<bool>;
+
+    /// Check if a DNS challenge is expired based on TTL
+    async fn is_dns_challenge_expired(&self, domain: &str, max_ttl_seconds: u64) -> Result<bool>;
+
+    /// Clean up expired challenges
+    async fn cleanup_expired_challenges(&self, max_ttl_seconds: u64) -> Result<()>;
+
+    /// Get the path for static files (well-known directory)
+    fn static_path(&self) -> PathBuf;
+
+    /// Read fullchain synchronously (for compatibility with sync APIs like rustls)
+    /// Returns None if the storage backend doesn't support sync operations
+    fn read_fullchain_sync(&self) -> Option<Result<Vec<u8>>> {
+        None
+    }
+
+    /// Read private key synchronously (for compatibility with sync APIs like rustls)
+    /// Returns None if the storage backend doesn't support sync operations
+    fn read_key_sync(&self) -> Option<Result<Vec<u8>>> {
+        None
+    }
+
+    /// Read ACME account credentials from storage
+    /// Returns None if credentials don't exist
+    async fn read_account_credentials(&self) -> Result<Option<String>>;
+
+    /// Write ACME account credentials to storage
+    async fn write_account_credentials(&self, credentials: &str) -> Result<()>;
+
+    /// Record a certificate generation failure
+    async fn record_failure(&self, error: &str) -> Result<()>;
+
+    /// Get the last failure timestamp and error message
+    /// Returns None if no failure was recorded
+    async fn get_last_failure(&self) -> Result<Option<(chrono::DateTime<chrono::Utc>, String)>>;
+
+    /// Clear failure record (called when certificate is successfully generated)
+    async fn clear_failure(&self) -> Result<()>;
+
+    /// Get the number of consecutive failures
+    async fn get_failure_count(&self) -> Result<u32>;
+}
+
+/// Factory for creating storage backends
+pub struct StorageFactory;
+
+impl StorageFactory {
+    /// Create a storage backend (Redis only)
+    pub fn create(storage_type: StorageType, config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
+        match storage_type {
+            StorageType::Redis => {
+                Ok(Box::new(RedisStorage::new(config)?))
+            }
+        }
+    }
+
+    /// Create storage from AppConfig storage settings (Redis only)
+    pub fn create_from_app_config(_app_config: &crate::acme::AppConfig, domain_config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
+        Self::create(StorageType::Redis, domain_config)
+    }
+
+    /// Create storage based on config settings (Redis only)
+    pub fn create_default(config: &crate::acme::Config) -> Result<Box<dyn Storage>> {
+        tracing::debug!("Creating Redis storage backend for domain: {}", config.opts.domain);
+        Self::create(StorageType::Redis, config)
+    }
+}
+
diff --git a/src/acme/storage/redis.rs b/src/acme/storage/redis.rs
index e687437..4f3e0b7 100644
--- a/src/acme/storage/redis.rs
+++ b/src/acme/storage/redis.rs
@@ -1,689 +1,689 @@
-//! 
Redis storage backend implementation - -use crate::acme::Config; -use super::Storage; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use redis::AsyncCommands; -use std::path::PathBuf; -use std::sync::Arc; - -/// Redis storage backend -pub struct RedisStorage { - client: Arc, - base_key: String, - static_path: PathBuf, - // Optional shared connection manager from RedisManager (for connection pooling) - // ConnectionManager is Clone, so we can store a clone of it - shared_connection: Option, -} - -impl RedisStorage { - /// Create a new Redis storage backend - pub fn new(config: &Config) -> Result { - // Get Redis URL from config, environment, or use default - let redis_url = config.opts.redis_url.clone() - .or_else(|| std::env::var("REDIS_URL").ok()) - .unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()); - - // Create Redis client (ConnectionManager will be reused from RedisManager if available) - let client = if let Some(ssl_config) = &config.opts.redis_ssl { - Self::create_client_with_ssl(&redis_url, ssl_config)? - } else { - redis::Client::open(redis_url.as_str()) - .with_context(|| format!("Failed to connect to Redis at {}", redis_url))? - }; - - // Get Redis prefix from RedisManager if available, otherwise use default - let prefix = crate::redis::RedisManager::get() - .map(|rm| rm.get_prefix().to_string()) - .unwrap_or_else(|_| "ssl-storage".to_string()); - - // Normalize domain: strip wildcard prefix (*.) for consistent Redis key naming - // This ensures wildcard certificates (*.example.com) are stored with the same key - // as when they're looked up (example.com) - let domain = config.opts.domain.clone(); - let normalized_domain = domain.strip_prefix("*.").unwrap_or(&domain); - let base_key = format!("{}:{}", prefix, normalized_domain); - - // Try to reuse the connection manager from RedisManager for connection pooling - // ConnectionManager is Clone and shares the underlying connection pool - let shared_connection = crate::redis::RedisManager::get() - .ok() - .map(|rm| rm.get_connection()); - - Ok(Self { - client: Arc::new(client), - base_key, - static_path: config.static_path.clone(), - shared_connection, - }) - } - - /// Create Redis client with custom SSL/TLS configuration - fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { - use native_tls::{Certificate, Identity, TlsConnector}; - - // Build TLS connector with custom certificates - let mut tls_builder = TlsConnector::builder(); - - // Load CA certificate if provided - if let Some(ca_cert_path) = &ssl_config.ca_cert_path { - let ca_cert_data = std::fs::read(ca_cert_path) - .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; - let ca_cert = Certificate::from_pem(&ca_cert_data) - .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; - tls_builder.add_root_certificate(ca_cert); - tracing::info!("Loaded CA certificate from {}", ca_cert_path); - } - - // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { - let client_cert_data = std::fs::read(client_cert_path) - .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; - let client_key_data = std::fs::read(client_key_path) - .with_context(|| format!("Failed to read client key from {}", client_key_path))?; - - // Try to create identity from PEM format (cert + key) - let identity = 
Identity::from_pkcs8(&client_cert_data, &client_key_data) - .or_else(|_| { - // Try PEM format if PKCS#8 fails - Identity::from_pkcs12(&client_cert_data, "") - }) - .or_else(|_| { - // Try loading as separate PEM files - // Combine cert and key into a single PEM - let mut combined = client_cert_data.clone(); - combined.extend_from_slice(b"\n"); - combined.extend_from_slice(&client_key_data); - Identity::from_pkcs12(&combined, "") - }) - .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; - tls_builder.identity(identity); - tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); - } - - // Configure certificate verification - if ssl_config.insecure { - tls_builder.danger_accept_invalid_certs(true); - tls_builder.danger_accept_invalid_hostnames(true); - tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); - } - - let _tls_connector = tls_builder.build() - .with_context(|| "Failed to build TLS connector")?; - - // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, - // but doesn't expose a way to pass a custom TlsConnector. However, when using - // rediss:// URLs, it will use the system trust store. For custom CA certificates, - // we need to add them to the system trust store or use a workaround. - // - // For now, we'll create the client normally. The TLS configuration above - // validates the certificates, but the redis crate will use its own TLS setup. - // - // TODO: The redis crate doesn't support custom TlsConnector directly. - // We may need to: - // 1. Add CA cert to system trust store (requires system-level changes) - // 2. Use a different Redis client that supports custom TLS - // 3. 
Wait for redis crate to support custom TLS configuration - - // For insecure mode, the redis crate should handle it via rediss:// URL - // For custom CA certs, we'll need to rely on the system trust store - // or use environment variables if the redis crate supports it - - let client = redis::Client::open(redis_url) - .with_context(|| format!("Failed to create Redis client with SSL config"))?; - - Ok(client) - } - - /// Get Redis connection (reuses pooled connection if available) - async fn get_conn(&self) -> Result { - use redis::aio::ConnectionManager; - - // Prefer shared connection from RedisManager (connection pooling) - // ConnectionManager is Clone and shares the underlying connection pool - if let Some(shared_conn) = &self.shared_connection { - return Ok(shared_conn.clone()); - } - - // Fallback: create new connection manager from client - // This still uses connection pooling internally via the client - let client = Arc::as_ref(&self.client); - ConnectionManager::new(client.clone()) - .await - .with_context(|| "Failed to get Redis connection") - } - - /// Get key for live certificate - fn live_key(&self, file_type: &str) -> String { - format!("{}:live:{}", self.base_key, file_type) - } - - - /// Get metadata key - fn metadata_key(&self, key: &str) -> String { - format!("{}:metadata:{}", self.base_key, key) - } - - /// Get key for challenge token - fn challenge_key(&self, token: &str) -> String { - format!("{}:challenge:{}", self.base_key, token) - } - - /// Get key for DNS challenge - fn dns_challenge_key(&self) -> String { - format!("{}:dns-challenge", self.base_key) - } - - /// Get key for distributed lock - fn lock_key(&self) -> String { - format!("{}:lock", self.base_key) - } - - /// Acquire a distributed lock for this domain - /// Returns true if lock was acquired, false if already locked - /// Lock expires after `ttl_seconds` to prevent deadlocks - pub async fn acquire_lock(&self, ttl_seconds: u64) -> Result { - let mut conn = self.get_conn().await?; - let lock_key = self.lock_key(); - - // Use SET with NX (only set if not exists) and EX (expiration) for atomic lock acquisition - let result: Option<()> = redis::cmd("SET") - .arg(&lock_key) - .arg("locked") - .arg("NX") // Only set if key doesn't exist - .arg("EX") // Set expiration - .arg(ttl_seconds) - .query_async(&mut conn) - .await - .with_context(|| format!("Failed to acquire lock for key: {}", lock_key))?; - - Ok(result.is_some()) - } - - /// Release the distributed lock for this domain - pub async fn release_lock(&self) -> Result<()> { - let mut conn = self.get_conn().await?; - let lock_key = self.lock_key(); - conn.del::<_, ()>(&lock_key).await - .with_context(|| format!("Failed to release lock for key: {}", lock_key))?; - Ok(()) - } - - /// Execute a function with a distributed lock - /// Returns Ok with default value if the lock cannot be acquired (skips operation) - pub async fn with_lock(&self, ttl_seconds: u64, f: F) -> Result - where - F: FnOnce() -> Fut, - Fut: std::future::Future>, - T: Default, - { - if !self.acquire_lock(ttl_seconds).await? { - tracing::warn!("Failed to acquire lock for domain - another instance is processing this domain. 
Skipping operation."); - return Ok(Default::default()); - } - - let result = f().await; - - // Always try to release the lock, even if f() returned an error - if let Err(e) = self.release_lock().await { - tracing::warn!("Failed to release lock: {}", e); - } - - result - } - - /// Delete all archived certificates (Redis doesn't need to keep old versions) - async fn delete_archived_certs(&self, conn: &mut redis::aio::ConnectionManager) -> Result<()> { - // Get all keys matching the archive pattern - let archive_pattern = format!("{}:archive:*", self.base_key); - let keys: Vec = conn.keys(&archive_pattern).await - .with_context(|| format!("Failed to get archive keys matching {}", archive_pattern))?; - - // Delete all archived keys - if !keys.is_empty() { - conn.del::<_, ()>(keys).await - .with_context(|| "Failed to delete archived certificates")?; - } - - Ok(()) - } -} - -#[async_trait] -impl Storage for RedisStorage { - async fn read_cert(&self) -> Result> { - let mut conn = self.get_conn().await?; - let key = self.live_key("cert"); - let data: Vec = conn.get(&key).await - .with_context(|| format!("Failed to read certificate from Redis key: {}", key))?; - Ok(data) - } - - async fn read_chain(&self) -> Result> { - let mut conn = self.get_conn().await?; - let key = self.live_key("chain"); - let data: Vec = conn.get(&key).await - .with_context(|| format!("Failed to read chain from Redis key: {}", key))?; - Ok(data) - } - - async fn read_fullchain(&self) -> Result> { - let mut conn = self.get_conn().await?; - let key = self.live_key("fullchain"); - let data: Vec = conn.get(&key).await - .with_context(|| format!("Failed to read fullchain from Redis key: {}", key))?; - Ok(data) - } - - async fn read_key(&self) -> Result> { - let mut conn = self.get_conn().await?; - let key = self.live_key("privkey"); - let data: Vec = conn.get(&key).await - .with_context(|| format!("Failed to read private key from Redis key: {}", key))?; - Ok(data) - } - - async fn write_certs(&self, cert: &[u8], chain: &[u8], key: &[u8]) -> Result<()> { - tracing::debug!("Connecting to Redis for certificate storage..."); - let mut conn = self.get_conn().await - .with_context(|| "Failed to get Redis connection")?; - tracing::debug!("Redis connection established"); - - // Combine cert and chain to create fullchain - let mut fullchain = cert.to_vec(); - fullchain.extend_from_slice(chain); - tracing::debug!("Combined certificate chain (cert: {} bytes, chain: {} bytes, fullchain: {} bytes)", - cert.len(), chain.len(), fullchain.len()); - - // Calculate SHA256 hash of fullchain + key for change detection - use sha2::{Sha256, Digest}; - let mut hasher = Sha256::new(); - hasher.update(&fullchain); - hasher.update(key); - let hash = format!("{:x}", hasher.finalize()); - tracing::debug!("Calculated certificate hash: {}", hash); - - // Delete old archived certificates (keep only live/current version) - self.delete_archived_certs(&mut conn).await?; - - // Update live keys (current version only, no archive in Redis) - let live_files = [ - ("cert", cert), - ("chain", chain), - ("fullchain", &fullchain), - ("privkey", key), - ]; - - for (file_type, content) in live_files.iter() { - let live_key = self.live_key(file_type); - tracing::debug!("Writing {} to Redis key: {} ({} bytes)", file_type, live_key, content.len()); - conn.set::<_, _, ()>(&live_key, content).await - .with_context(|| format!("Failed to write live {} to Redis key: {}", file_type, live_key))?; - tracing::debug!("Successfully wrote {} to Redis key: {}", file_type, live_key); - } - 
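// Hedged sketch (illustration only; issue_certificate is a hypothetical async
// operation): drive a renewal under the distributed lock above so only one
// instance processes a domain. If another instance holds the lock, with_lock
// skips the work and returns the default value.
async fn renew_with_lock(storage: &RedisStorage) -> anyhow::Result<()> {
    storage
        .with_lock(300, || async {
            // issue_certificate().await?; // request and store the new certificate
            Ok(())
        })
        .await
}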
- // Store certificate hash for change detection - let hash_key = self.metadata_key("certificate_hash"); - conn.set::<_, _, ()>(&hash_key, &hash).await - .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; - tracing::debug!("Stored certificate hash: {} at key: {}", hash, hash_key); - - tracing::info!("All certificates written successfully to Redis for domain: {}", self.base_key); - Ok(()) - } - - async fn cert_exists(&self) -> bool { - let mut conn = match self.get_conn().await { - Ok(c) => c, - Err(_) => return false, - }; - - let key = self.live_key("cert"); - conn.exists(&key).await.unwrap_or(false) - } - - async fn read_created_at(&self) -> Result> { - let mut conn = self.get_conn().await?; - let key = self.metadata_key("created_at"); - let content: String = conn.get(&key).await - .with_context(|| format!("Failed to read created_at from Redis key: {}", key))?; - content - .parse::>() - .with_context(|| format!("Failed to parse created_at: {}", content)) - } - - async fn write_created_at(&self, created_at: chrono::DateTime) -> Result<()> { - let mut conn = self.get_conn().await?; - let key = self.metadata_key("created_at"); - conn.set::<_, _, ()>(&key, created_at.to_string()).await - .with_context(|| format!("Failed to write created_at to Redis key: {}", key))?; - Ok(()) - } - - async fn write_challenge(&self, token: &str, key_auth: &str) -> Result<()> { - // Store challenge token in Redis - let mut conn = self.get_conn().await?; - let challenge_key = self.challenge_key(token); - conn.set::<_, _, ()>(&challenge_key, key_auth).await - .with_context(|| format!("Failed to write challenge token to Redis key: {}", challenge_key))?; - - // Store challenge timestamp - let timestamp = chrono::Utc::now(); - let timestamp_key = format!("{}:timestamp", challenge_key); - conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await - .with_context(|| format!("Failed to write challenge timestamp to Redis key: {}", timestamp_key))?; - - // Also write to filesystem for HTTP-01 challenge serving - // The HTTP server needs to serve these files - // Write challenge files to a shared location (not per-domain) - // This allows the HTTP server to serve them from a single base path - let base_path = self.static_path.parent() - .ok_or_else(|| anyhow!("Cannot get parent path from static_path"))? 
- .to_path_buf(); - - let mut well_known_folder = base_path.clone(); - well_known_folder.push("well-known"); - tokio::fs::create_dir_all(&well_known_folder) - .await - .with_context(|| format!("Failed to create well-known directory {:?}", well_known_folder))?; - - let mut challenge_path = well_known_folder.clone(); - challenge_path.push("acme-challenge"); - tokio::fs::create_dir_all(&challenge_path) - .await - .with_context(|| format!("Failed to create acme-challenge directory {:?}", challenge_path))?; - - challenge_path.push(token); - tokio::fs::write(&challenge_path, key_auth) - .await - .with_context(|| format!("Failed to write challenge file {:?}", challenge_path))?; - - Ok(()) - } - - fn static_path(&self) -> PathBuf { - self.static_path.clone() - } - - fn read_fullchain_sync(&self) -> Option>> { - // Redis doesn't support sync operations easily, return None - None - } - - fn read_key_sync(&self) -> Option>> { - // Redis doesn't support sync operations easily, return None - None - } - - async fn write_dns_challenge(&self, _domain: &str, dns_record: &str, dns_value: &str) -> Result<()> { - // Store DNS challenge code in Redis - let mut conn = self.get_conn().await?; - let dns_key = self.dns_challenge_key(); - - // Store as JSON with dns_record and challenge_code - let challenge_data = serde_json::json!({ - "dns_record": dns_record, - "challenge_code": dns_value, - }); - - conn.set::<_, _, ()>(&dns_key, challenge_data.to_string()).await - .with_context(|| format!("Failed to write DNS challenge to Redis key: {}", dns_key))?; - - // Store DNS challenge timestamp - let timestamp = chrono::Utc::now(); - let timestamp_key = format!("{}:timestamp", dns_key); - conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await - .with_context(|| format!("Failed to write DNS challenge timestamp to Redis key: {}", timestamp_key))?; - - tracing::info!("DNS challenge code saved to Redis: {} = {}", dns_record, dns_value); - Ok(()) - } - - async fn read_account_credentials(&self) -> Result> { - let mut conn = self.get_conn().await?; - // Use a shared key for account credentials (not per-domain) - // Get prefix from RedisManager if available - let prefix = crate::redis::RedisManager::get() - .map(|rm| rm.get_prefix().to_string()) - .unwrap_or_else(|_| "ssl-storage".to_string()); - let creds_key = format!("{}:acme:account_credentials", prefix); - let creds_key_clone = creds_key.clone(); - - let result: Option = conn.get(&creds_key).await - .with_context(|| format!("Failed to read account credentials from Redis key: {}", creds_key_clone))?; - - Ok(result) - } - - async fn write_account_credentials(&self, credentials: &str) -> Result<()> { - let mut conn = self.get_conn().await?; - // Use a shared key for account credentials (not per-domain) - // Get prefix from RedisManager if available - let prefix = crate::redis::RedisManager::get() - .map(|rm| rm.get_prefix().to_string()) - .unwrap_or_else(|_| "ssl-storage".to_string()); - let creds_key = format!("{}:acme:account_credentials", prefix); - let creds_key_clone = creds_key.clone(); - - conn.set::<_, _, ()>(&creds_key, credentials).await - .with_context(|| format!("Failed to write account credentials to Redis key: {}", creds_key_clone))?; - - Ok(()) - } - - async fn record_failure(&self, error: &str) -> Result<()> { - let mut conn = self.get_conn().await?; - let failure_key = self.metadata_key("cert_failure"); - let count_key = self.metadata_key("cert_failure_count"); - - // Read current failure count - let count: u32 = 
conn.get(&count_key).await.unwrap_or(0); - let new_count = count + 1; - - // Write failure record - let failure_data = serde_json::json!({ - "timestamp": chrono::Utc::now().to_rfc3339(), - "error": error, - "count": new_count, - }); - - conn.set::<_, _, ()>(&failure_key, failure_data.to_string()).await - .with_context(|| format!("Failed to write failure record to Redis key: {}", failure_key))?; - - // Write failure count - conn.set::<_, _, ()>(&count_key, new_count.to_string()).await - .with_context(|| format!("Failed to write failure count to Redis key: {}", count_key))?; - - Ok(()) - } - - async fn get_last_failure(&self) -> Result, String)>> { - let mut conn = self.get_conn().await?; - let failure_key = self.metadata_key("cert_failure"); - - let content: Option = conn.get(&failure_key).await - .with_context(|| format!("Failed to read failure record from Redis key: {}", failure_key))?; - - let content = match content { - Some(c) => c, - None => return Ok(None), - }; - - let failure_data: serde_json::Value = serde_json::from_str(&content) - .with_context(|| format!("Failed to parse failure record: {}", content))?; - - let timestamp_str = failure_data["timestamp"] - .as_str() - .ok_or_else(|| anyhow!("Missing timestamp in failure record"))?; - let timestamp = chrono::DateTime::parse_from_rfc3339(timestamp_str) - .with_context(|| format!("Failed to parse timestamp: {}", timestamp_str))? - .with_timezone(&chrono::Utc); - - let error = failure_data["error"] - .as_str() - .ok_or_else(|| anyhow!("Missing error in failure record"))? - .to_string(); - - Ok(Some((timestamp, error))) - } - - async fn clear_failure(&self) -> Result<()> { - let mut conn = self.get_conn().await?; - let failure_key = self.metadata_key("cert_failure"); - let count_key = self.metadata_key("cert_failure_count"); - - conn.del::<_, ()>(&failure_key).await.ok(); - conn.del::<_, ()>(&count_key).await.ok(); - - Ok(()) - } - - async fn get_failure_count(&self) -> Result { - let mut conn = self.get_conn().await?; - let count_key = self.metadata_key("cert_failure_count"); - - let count: Option = conn.get(&count_key).await - .with_context(|| format!("Failed to read failure count from Redis key: {}", count_key))?; - - let count = match count { - Some(c) => c.trim().parse::().unwrap_or(0), - None => 0, - }; - - Ok(count) - } - - async fn get_certificate_hash(&self) -> Result> { - let mut conn = self.get_conn().await?; - let hash_key = self.metadata_key("certificate_hash"); - - // Check if hash exists - let hash: Option = conn.get(&hash_key).await - .with_context(|| format!("Failed to read certificate hash from Redis key: {}", hash_key))?; - - if let Some(hash) = hash { - return Ok(Some(hash)); - } - - // Hash doesn't exist, but check if certificate exists - if !self.cert_exists().await { - return Ok(None); - } - - // Certificate exists but hash doesn't - generate it - let fullchain = self.read_fullchain().await?; - let key = self.read_key().await?; - - // Calculate SHA256 hash of fullchain + key - use sha2::{Sha256, Digest}; - let mut hasher = Sha256::new(); - hasher.update(&fullchain); - hasher.update(&key); - let hash = format!("{:x}", hasher.finalize()); - - // Store the hash for future use - conn.set::<_, _, ()>(&hash_key, &hash).await - .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; - - Ok(Some(hash)) - } - - async fn get_challenge_timestamp(&self, token: &str) -> Result>> { - let mut conn = self.get_conn().await?; - let challenge_key = self.challenge_key(token); - let timestamp_key = 
format!("{}:timestamp", challenge_key); - - let content: Option = conn.get(×tamp_key).await - .with_context(|| format!("Failed to read challenge timestamp from Redis key: {}", timestamp_key))?; - - let content = match content { - Some(c) => c, - None => return Ok(None), - }; - - let timestamp = chrono::DateTime::parse_from_rfc3339(content.trim()) - .with_context(|| format!("Failed to parse challenge timestamp: {}", content))? - .with_timezone(&chrono::Utc); - - Ok(Some(timestamp)) - } - - async fn get_dns_challenge_timestamp(&self, _domain: &str) -> Result>> { - let mut conn = self.get_conn().await?; - let dns_key = self.dns_challenge_key(); - let timestamp_key = format!("{}:timestamp", dns_key); - - let content: Option = conn.get(×tamp_key).await - .with_context(|| format!("Failed to read DNS challenge timestamp from Redis key: {}", timestamp_key))?; - - let content = match content { - Some(c) => c, - None => return Ok(None), - }; - - let timestamp = chrono::DateTime::parse_from_rfc3339(content.trim()) - .with_context(|| format!("Failed to parse DNS challenge timestamp: {}", content))? - .with_timezone(&chrono::Utc); - - Ok(Some(timestamp)) - } - - async fn is_challenge_expired(&self, token: &str, max_ttl_seconds: u64) -> Result { - let timestamp = match self.get_challenge_timestamp(token).await? { - Some(ts) => ts, - None => return Ok(true), // No timestamp means expired - }; - - let now = chrono::Utc::now(); - let age = now - timestamp; - let age_seconds = age.num_seconds() as u64; - - Ok(age_seconds >= max_ttl_seconds) - } - - async fn is_dns_challenge_expired(&self, _domain: &str, max_ttl_seconds: u64) -> Result { - let timestamp = match self.get_dns_challenge_timestamp(_domain).await? { - Some(ts) => ts, - None => return Ok(true), // No timestamp means expired - }; - - let now = chrono::Utc::now(); - let age = now - timestamp; - let age_seconds = age.num_seconds() as u64; - - Ok(age_seconds >= max_ttl_seconds) - } - - async fn cleanup_expired_challenges(&self, max_ttl_seconds: u64) -> Result<()> { - let mut conn = self.get_conn().await?; - - // Get all challenge keys matching the pattern - let challenge_pattern = format!("{}:challenge:*", self.base_key); - let keys: Vec = conn.keys(&challenge_pattern).await - .with_context(|| format!("Failed to get challenge keys matching {}", challenge_pattern))?; - - for challenge_key in keys { - // Skip timestamp keys - if challenge_key.ends_with(":timestamp") { - continue; - } - - // Extract token from key (format: base_key:challenge:token) - if let Some(token) = challenge_key.split(':').last() { - if let Ok(expired) = self.is_challenge_expired(token, max_ttl_seconds).await { - if expired { - let timestamp_key = format!("{}:timestamp", challenge_key); - // Remove challenge and timestamp - conn.del::<_, ()>(&challenge_key).await.ok(); - conn.del::<_, ()>(×tamp_key).await.ok(); - } - } - } - } - - Ok(()) - } -} - +//! 
Redis storage backend implementation + +use crate::acme::Config; +use super::Storage; +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use redis::AsyncCommands; +use std::path::PathBuf; +use std::sync::Arc; + +/// Redis storage backend +pub struct RedisStorage { + client: Arc, + base_key: String, + static_path: PathBuf, + // Optional shared connection manager from RedisManager (for connection pooling) + // ConnectionManager is Clone, so we can store a clone of it + shared_connection: Option, +} + +impl RedisStorage { + /// Create a new Redis storage backend + pub fn new(config: &Config) -> Result { + // Get Redis URL from config, environment, or use default + let redis_url = config.opts.redis_url.clone() + .or_else(|| std::env::var("REDIS_URL").ok()) + .unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()); + + // Create Redis client (ConnectionManager will be reused from RedisManager if available) + let client = if let Some(ssl_config) = &config.opts.redis_ssl { + Self::create_client_with_ssl(&redis_url, ssl_config)? + } else { + redis::Client::open(redis_url.as_str()) + .with_context(|| format!("Failed to connect to Redis at {}", redis_url))? + }; + + // Get Redis prefix from RedisManager if available, otherwise use default + let prefix = crate::redis::RedisManager::get() + .map(|rm| rm.get_prefix().to_string()) + .unwrap_or_else(|_| "ssl-storage".to_string()); + + // Normalize domain: strip wildcard prefix (*.) for consistent Redis key naming + // This ensures wildcard certificates (*.example.com) are stored with the same key + // as when they're looked up (example.com) + let domain = config.opts.domain.clone(); + let normalized_domain = domain.strip_prefix("*.").unwrap_or(&domain); + let base_key = format!("{}:{}", prefix, normalized_domain); + + // Try to reuse the connection manager from RedisManager for connection pooling + // ConnectionManager is Clone and shares the underlying connection pool + let shared_connection = crate::redis::RedisManager::get() + .ok() + .map(|rm| rm.get_connection()); + + Ok(Self { + client: Arc::new(client), + base_key, + static_path: config.static_path.clone(), + shared_connection, + }) + } + + /// Create Redis client with custom SSL/TLS configuration + fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { + use native_tls::{Certificate, Identity, TlsConnector}; + + // Build TLS connector with custom certificates + let mut tls_builder = TlsConnector::builder(); + + // Load CA certificate if provided + if let Some(ca_cert_path) = &ssl_config.ca_cert_path { + let ca_cert_data = std::fs::read(ca_cert_path) + .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; + let ca_cert = Certificate::from_pem(&ca_cert_data) + .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; + tls_builder.add_root_certificate(ca_cert); + tracing::info!("Loaded CA certificate from {}", ca_cert_path); + } + + // Load client certificate and key if provided + if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { + let client_cert_data = std::fs::read(client_cert_path) + .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; + let client_key_data = std::fs::read(client_key_path) + .with_context(|| format!("Failed to read client key from {}", client_key_path))?; + + // Try to create identity from PEM format (cert + key) + let identity = 
Identity::from_pkcs8(&client_cert_data, &client_key_data) + .or_else(|_| { + // Try PEM format if PKCS#8 fails + Identity::from_pkcs12(&client_cert_data, "") + }) + .or_else(|_| { + // Try loading as separate PEM files + // Combine cert and key into a single PEM + let mut combined = client_cert_data.clone(); + combined.extend_from_slice(b"\n"); + combined.extend_from_slice(&client_key_data); + Identity::from_pkcs12(&combined, "") + }) + .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; + tls_builder.identity(identity); + tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); + } + + // Configure certificate verification + if ssl_config.insecure { + tls_builder.danger_accept_invalid_certs(true); + tls_builder.danger_accept_invalid_hostnames(true); + tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); + } + + let _tls_connector = tls_builder.build() + .with_context(|| "Failed to build TLS connector")?; + + // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, + // but doesn't expose a way to pass a custom TlsConnector. However, when using + // rediss:// URLs, it will use the system trust store. For custom CA certificates, + // we need to add them to the system trust store or use a workaround. + // + // For now, we'll create the client normally. The TLS configuration above + // validates the certificates, but the redis crate will use its own TLS setup. + // + // TODO: The redis crate doesn't support custom TlsConnector directly. + // We may need to: + // 1. Add CA cert to system trust store (requires system-level changes) + // 2. Use a different Redis client that supports custom TLS + // 3. 
Wait for redis crate to support custom TLS configuration + + // For insecure mode, the redis crate should handle it via rediss:// URL + // For custom CA certs, we'll need to rely on the system trust store + // or use environment variables if the redis crate supports it + + let client = redis::Client::open(redis_url) + .with_context(|| format!("Failed to create Redis client with SSL config"))?; + + Ok(client) + } + + /// Get Redis connection (reuses pooled connection if available) + async fn get_conn(&self) -> Result { + use redis::aio::ConnectionManager; + + // Prefer shared connection from RedisManager (connection pooling) + // ConnectionManager is Clone and shares the underlying connection pool + if let Some(shared_conn) = &self.shared_connection { + return Ok(shared_conn.clone()); + } + + // Fallback: create new connection manager from client + // This still uses connection pooling internally via the client + let client = Arc::as_ref(&self.client); + ConnectionManager::new(client.clone()) + .await + .with_context(|| "Failed to get Redis connection") + } + + /// Get key for live certificate + fn live_key(&self, file_type: &str) -> String { + format!("{}:live:{}", self.base_key, file_type) + } + + + /// Get metadata key + fn metadata_key(&self, key: &str) -> String { + format!("{}:metadata:{}", self.base_key, key) + } + + /// Get key for challenge token + fn challenge_key(&self, token: &str) -> String { + format!("{}:challenge:{}", self.base_key, token) + } + + /// Get key for DNS challenge + fn dns_challenge_key(&self) -> String { + format!("{}:dns-challenge", self.base_key) + } + + /// Get key for distributed lock + fn lock_key(&self) -> String { + format!("{}:lock", self.base_key) + } + + /// Acquire a distributed lock for this domain + /// Returns true if lock was acquired, false if already locked + /// Lock expires after `ttl_seconds` to prevent deadlocks + pub async fn acquire_lock(&self, ttl_seconds: u64) -> Result { + let mut conn = self.get_conn().await?; + let lock_key = self.lock_key(); + + // Use SET with NX (only set if not exists) and EX (expiration) for atomic lock acquisition + let result: Option<()> = redis::cmd("SET") + .arg(&lock_key) + .arg("locked") + .arg("NX") // Only set if key doesn't exist + .arg("EX") // Set expiration + .arg(ttl_seconds) + .query_async(&mut conn) + .await + .with_context(|| format!("Failed to acquire lock for key: {}", lock_key))?; + + Ok(result.is_some()) + } + + /// Release the distributed lock for this domain + pub async fn release_lock(&self) -> Result<()> { + let mut conn = self.get_conn().await?; + let lock_key = self.lock_key(); + conn.del::<_, ()>(&lock_key).await + .with_context(|| format!("Failed to release lock for key: {}", lock_key))?; + Ok(()) + } + + /// Execute a function with a distributed lock + /// Returns Ok with default value if the lock cannot be acquired (skips operation) + pub async fn with_lock(&self, ttl_seconds: u64, f: F) -> Result + where + F: FnOnce() -> Fut, + Fut: std::future::Future>, + T: Default, + { + if !self.acquire_lock(ttl_seconds).await? { + tracing::warn!("Failed to acquire lock for domain - another instance is processing this domain. 
Skipping operation."); + return Ok(Default::default()); + } + + let result = f().await; + + // Always try to release the lock, even if f() returned an error + if let Err(e) = self.release_lock().await { + tracing::warn!("Failed to release lock: {}", e); + } + + result + } + + /// Delete all archived certificates (Redis doesn't need to keep old versions) + async fn delete_archived_certs(&self, conn: &mut redis::aio::ConnectionManager) -> Result<()> { + // Get all keys matching the archive pattern + let archive_pattern = format!("{}:archive:*", self.base_key); + let keys: Vec = conn.keys(&archive_pattern).await + .with_context(|| format!("Failed to get archive keys matching {}", archive_pattern))?; + + // Delete all archived keys + if !keys.is_empty() { + conn.del::<_, ()>(keys).await + .with_context(|| "Failed to delete archived certificates")?; + } + + Ok(()) + } +} + +#[async_trait] +impl Storage for RedisStorage { + async fn read_cert(&self) -> Result> { + let mut conn = self.get_conn().await?; + let key = self.live_key("cert"); + let data: Vec = conn.get(&key).await + .with_context(|| format!("Failed to read certificate from Redis key: {}", key))?; + Ok(data) + } + + async fn read_chain(&self) -> Result> { + let mut conn = self.get_conn().await?; + let key = self.live_key("chain"); + let data: Vec = conn.get(&key).await + .with_context(|| format!("Failed to read chain from Redis key: {}", key))?; + Ok(data) + } + + async fn read_fullchain(&self) -> Result> { + let mut conn = self.get_conn().await?; + let key = self.live_key("fullchain"); + let data: Vec = conn.get(&key).await + .with_context(|| format!("Failed to read fullchain from Redis key: {}", key))?; + Ok(data) + } + + async fn read_key(&self) -> Result> { + let mut conn = self.get_conn().await?; + let key = self.live_key("privkey"); + let data: Vec = conn.get(&key).await + .with_context(|| format!("Failed to read private key from Redis key: {}", key))?; + Ok(data) + } + + async fn write_certs(&self, cert: &[u8], chain: &[u8], key: &[u8]) -> Result<()> { + tracing::debug!("Connecting to Redis for certificate storage..."); + let mut conn = self.get_conn().await + .with_context(|| "Failed to get Redis connection")?; + tracing::debug!("Redis connection established"); + + // Combine cert and chain to create fullchain + let mut fullchain = cert.to_vec(); + fullchain.extend_from_slice(chain); + tracing::debug!("Combined certificate chain (cert: {} bytes, chain: {} bytes, fullchain: {} bytes)", + cert.len(), chain.len(), fullchain.len()); + + // Calculate SHA256 hash of fullchain + key for change detection + use sha2::{Sha256, Digest}; + let mut hasher = Sha256::new(); + hasher.update(&fullchain); + hasher.update(key); + let hash = format!("{:x}", hasher.finalize()); + tracing::debug!("Calculated certificate hash: {}", hash); + + // Delete old archived certificates (keep only live/current version) + self.delete_archived_certs(&mut conn).await?; + + // Update live keys (current version only, no archive in Redis) + let live_files = [ + ("cert", cert), + ("chain", chain), + ("fullchain", &fullchain), + ("privkey", key), + ]; + + for (file_type, content) in live_files.iter() { + let live_key = self.live_key(file_type); + tracing::debug!("Writing {} to Redis key: {} ({} bytes)", file_type, live_key, content.len()); + conn.set::<_, _, ()>(&live_key, content).await + .with_context(|| format!("Failed to write live {} to Redis key: {}", file_type, live_key))?; + tracing::debug!("Successfully wrote {} to Redis key: {}", file_type, live_key); + } + 
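// Hedged sketch (illustration only, assuming a caller-side `last_hash` cache):
// the SHA256 hash stored below lets consumers detect certificate rotation
// without comparing full PEM payloads.
async fn cert_changed(storage: &dyn Storage, last_hash: Option<&str>) -> anyhow::Result<bool> {
    Ok(match (storage.get_certificate_hash().await?, last_hash) {
        (Some(current), Some(previous)) => current != previous, // rotated if hashes differ
        (Some(_), None) => true,                                // first observation counts as a change
        (None, _) => false,                                     // no certificate stored yet
    })
}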
+ // Store certificate hash for change detection + let hash_key = self.metadata_key("certificate_hash"); + conn.set::<_, _, ()>(&hash_key, &hash).await + .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; + tracing::debug!("Stored certificate hash: {} at key: {}", hash, hash_key); + + tracing::info!("All certificates written successfully to Redis for domain: {}", self.base_key); + Ok(()) + } + + async fn cert_exists(&self) -> bool { + let mut conn = match self.get_conn().await { + Ok(c) => c, + Err(_) => return false, + }; + + let key = self.live_key("cert"); + conn.exists(&key).await.unwrap_or(false) + } + + async fn read_created_at(&self) -> Result> { + let mut conn = self.get_conn().await?; + let key = self.metadata_key("created_at"); + let content: String = conn.get(&key).await + .with_context(|| format!("Failed to read created_at from Redis key: {}", key))?; + content + .parse::>() + .with_context(|| format!("Failed to parse created_at: {}", content)) + } + + async fn write_created_at(&self, created_at: chrono::DateTime) -> Result<()> { + let mut conn = self.get_conn().await?; + let key = self.metadata_key("created_at"); + conn.set::<_, _, ()>(&key, created_at.to_string()).await + .with_context(|| format!("Failed to write created_at to Redis key: {}", key))?; + Ok(()) + } + + async fn write_challenge(&self, token: &str, key_auth: &str) -> Result<()> { + // Store challenge token in Redis + let mut conn = self.get_conn().await?; + let challenge_key = self.challenge_key(token); + conn.set::<_, _, ()>(&challenge_key, key_auth).await + .with_context(|| format!("Failed to write challenge token to Redis key: {}", challenge_key))?; + + // Store challenge timestamp + let timestamp = chrono::Utc::now(); + let timestamp_key = format!("{}:timestamp", challenge_key); + conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await + .with_context(|| format!("Failed to write challenge timestamp to Redis key: {}", timestamp_key))?; + + // Also write to filesystem for HTTP-01 challenge serving + // The HTTP server needs to serve these files + // Write challenge files to a shared location (not per-domain) + // This allows the HTTP server to serve them from a single base path + let base_path = self.static_path.parent() + .ok_or_else(|| anyhow!("Cannot get parent path from static_path"))? 
+ .to_path_buf(); + + let mut well_known_folder = base_path.clone(); + well_known_folder.push("well-known"); + tokio::fs::create_dir_all(&well_known_folder) + .await + .with_context(|| format!("Failed to create well-known directory {:?}", well_known_folder))?; + + let mut challenge_path = well_known_folder.clone(); + challenge_path.push("acme-challenge"); + tokio::fs::create_dir_all(&challenge_path) + .await + .with_context(|| format!("Failed to create acme-challenge directory {:?}", challenge_path))?; + + challenge_path.push(token); + tokio::fs::write(&challenge_path, key_auth) + .await + .with_context(|| format!("Failed to write challenge file {:?}", challenge_path))?; + + Ok(()) + } + + fn static_path(&self) -> PathBuf { + self.static_path.clone() + } + + fn read_fullchain_sync(&self) -> Option>> { + // Redis doesn't support sync operations easily, return None + None + } + + fn read_key_sync(&self) -> Option>> { + // Redis doesn't support sync operations easily, return None + None + } + + async fn write_dns_challenge(&self, _domain: &str, dns_record: &str, dns_value: &str) -> Result<()> { + // Store DNS challenge code in Redis + let mut conn = self.get_conn().await?; + let dns_key = self.dns_challenge_key(); + + // Store as JSON with dns_record and challenge_code + let challenge_data = serde_json::json!({ + "dns_record": dns_record, + "challenge_code": dns_value, + }); + + conn.set::<_, _, ()>(&dns_key, challenge_data.to_string()).await + .with_context(|| format!("Failed to write DNS challenge to Redis key: {}", dns_key))?; + + // Store DNS challenge timestamp + let timestamp = chrono::Utc::now(); + let timestamp_key = format!("{}:timestamp", dns_key); + conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await + .with_context(|| format!("Failed to write DNS challenge timestamp to Redis key: {}", timestamp_key))?; + + tracing::info!("DNS challenge code saved to Redis: {} = {}", dns_record, dns_value); + Ok(()) + } + + async fn read_account_credentials(&self) -> Result> { + let mut conn = self.get_conn().await?; + // Use a shared key for account credentials (not per-domain) + // Get prefix from RedisManager if available + let prefix = crate::redis::RedisManager::get() + .map(|rm| rm.get_prefix().to_string()) + .unwrap_or_else(|_| "ssl-storage".to_string()); + let creds_key = format!("{}:acme:account_credentials", prefix); + let creds_key_clone = creds_key.clone(); + + let result: Option = conn.get(&creds_key).await + .with_context(|| format!("Failed to read account credentials from Redis key: {}", creds_key_clone))?; + + Ok(result) + } + + async fn write_account_credentials(&self, credentials: &str) -> Result<()> { + let mut conn = self.get_conn().await?; + // Use a shared key for account credentials (not per-domain) + // Get prefix from RedisManager if available + let prefix = crate::redis::RedisManager::get() + .map(|rm| rm.get_prefix().to_string()) + .unwrap_or_else(|_| "ssl-storage".to_string()); + let creds_key = format!("{}:acme:account_credentials", prefix); + let creds_key_clone = creds_key.clone(); + + conn.set::<_, _, ()>(&creds_key, credentials).await + .with_context(|| format!("Failed to write account credentials to Redis key: {}", creds_key_clone))?; + + Ok(()) + } + + async fn record_failure(&self, error: &str) -> Result<()> { + let mut conn = self.get_conn().await?; + let failure_key = self.metadata_key("cert_failure"); + let count_key = self.metadata_key("cert_failure_count"); + + // Read current failure count + let count: u32 = 
conn.get(&count_key).await.unwrap_or(0); + let new_count = count + 1; + + // Write failure record + let failure_data = serde_json::json!({ + "timestamp": chrono::Utc::now().to_rfc3339(), + "error": error, + "count": new_count, + }); + + conn.set::<_, _, ()>(&failure_key, failure_data.to_string()).await + .with_context(|| format!("Failed to write failure record to Redis key: {}", failure_key))?; + + // Write failure count + conn.set::<_, _, ()>(&count_key, new_count.to_string()).await + .with_context(|| format!("Failed to write failure count to Redis key: {}", count_key))?; + + Ok(()) + } + + async fn get_last_failure(&self) -> Result, String)>> { + let mut conn = self.get_conn().await?; + let failure_key = self.metadata_key("cert_failure"); + + let content: Option = conn.get(&failure_key).await + .with_context(|| format!("Failed to read failure record from Redis key: {}", failure_key))?; + + let content = match content { + Some(c) => c, + None => return Ok(None), + }; + + let failure_data: serde_json::Value = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse failure record: {}", content))?; + + let timestamp_str = failure_data["timestamp"] + .as_str() + .ok_or_else(|| anyhow!("Missing timestamp in failure record"))?; + let timestamp = chrono::DateTime::parse_from_rfc3339(timestamp_str) + .with_context(|| format!("Failed to parse timestamp: {}", timestamp_str))? + .with_timezone(&chrono::Utc); + + let error = failure_data["error"] + .as_str() + .ok_or_else(|| anyhow!("Missing error in failure record"))? + .to_string(); + + Ok(Some((timestamp, error))) + } + + async fn clear_failure(&self) -> Result<()> { + let mut conn = self.get_conn().await?; + let failure_key = self.metadata_key("cert_failure"); + let count_key = self.metadata_key("cert_failure_count"); + + conn.del::<_, ()>(&failure_key).await.ok(); + conn.del::<_, ()>(&count_key).await.ok(); + + Ok(()) + } + + async fn get_failure_count(&self) -> Result { + let mut conn = self.get_conn().await?; + let count_key = self.metadata_key("cert_failure_count"); + + let count: Option = conn.get(&count_key).await + .with_context(|| format!("Failed to read failure count from Redis key: {}", count_key))?; + + let count = match count { + Some(c) => c.trim().parse::().unwrap_or(0), + None => 0, + }; + + Ok(count) + } + + async fn get_certificate_hash(&self) -> Result> { + let mut conn = self.get_conn().await?; + let hash_key = self.metadata_key("certificate_hash"); + + // Check if hash exists + let hash: Option = conn.get(&hash_key).await + .with_context(|| format!("Failed to read certificate hash from Redis key: {}", hash_key))?; + + if let Some(hash) = hash { + return Ok(Some(hash)); + } + + // Hash doesn't exist, but check if certificate exists + if !self.cert_exists().await { + return Ok(None); + } + + // Certificate exists but hash doesn't - generate it + let fullchain = self.read_fullchain().await?; + let key = self.read_key().await?; + + // Calculate SHA256 hash of fullchain + key + use sha2::{Sha256, Digest}; + let mut hasher = Sha256::new(); + hasher.update(&fullchain); + hasher.update(&key); + let hash = format!("{:x}", hasher.finalize()); + + // Store the hash for future use + conn.set::<_, _, ()>(&hash_key, &hash).await + .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; + + Ok(Some(hash)) + } + + async fn get_challenge_timestamp(&self, token: &str) -> Result>> { + let mut conn = self.get_conn().await?; + let challenge_key = self.challenge_key(token); + let timestamp_key = 
format!("{}:timestamp", challenge_key); + + let content: Option = conn.get(×tamp_key).await + .with_context(|| format!("Failed to read challenge timestamp from Redis key: {}", timestamp_key))?; + + let content = match content { + Some(c) => c, + None => return Ok(None), + }; + + let timestamp = chrono::DateTime::parse_from_rfc3339(content.trim()) + .with_context(|| format!("Failed to parse challenge timestamp: {}", content))? + .with_timezone(&chrono::Utc); + + Ok(Some(timestamp)) + } + + async fn get_dns_challenge_timestamp(&self, _domain: &str) -> Result>> { + let mut conn = self.get_conn().await?; + let dns_key = self.dns_challenge_key(); + let timestamp_key = format!("{}:timestamp", dns_key); + + let content: Option = conn.get(×tamp_key).await + .with_context(|| format!("Failed to read DNS challenge timestamp from Redis key: {}", timestamp_key))?; + + let content = match content { + Some(c) => c, + None => return Ok(None), + }; + + let timestamp = chrono::DateTime::parse_from_rfc3339(content.trim()) + .with_context(|| format!("Failed to parse DNS challenge timestamp: {}", content))? + .with_timezone(&chrono::Utc); + + Ok(Some(timestamp)) + } + + async fn is_challenge_expired(&self, token: &str, max_ttl_seconds: u64) -> Result { + let timestamp = match self.get_challenge_timestamp(token).await? { + Some(ts) => ts, + None => return Ok(true), // No timestamp means expired + }; + + let now = chrono::Utc::now(); + let age = now - timestamp; + let age_seconds = age.num_seconds() as u64; + + Ok(age_seconds >= max_ttl_seconds) + } + + async fn is_dns_challenge_expired(&self, _domain: &str, max_ttl_seconds: u64) -> Result { + let timestamp = match self.get_dns_challenge_timestamp(_domain).await? { + Some(ts) => ts, + None => return Ok(true), // No timestamp means expired + }; + + let now = chrono::Utc::now(); + let age = now - timestamp; + let age_seconds = age.num_seconds() as u64; + + Ok(age_seconds >= max_ttl_seconds) + } + + async fn cleanup_expired_challenges(&self, max_ttl_seconds: u64) -> Result<()> { + let mut conn = self.get_conn().await?; + + // Get all challenge keys matching the pattern + let challenge_pattern = format!("{}:challenge:*", self.base_key); + let keys: Vec = conn.keys(&challenge_pattern).await + .with_context(|| format!("Failed to get challenge keys matching {}", challenge_pattern))?; + + for challenge_key in keys { + // Skip timestamp keys + if challenge_key.ends_with(":timestamp") { + continue; + } + + // Extract token from key (format: base_key:challenge:token) + if let Some(token) = challenge_key.split(':').last() { + if let Ok(expired) = self.is_challenge_expired(token, max_ttl_seconds).await { + if expired { + let timestamp_key = format!("{}:timestamp", challenge_key); + // Remove challenge and timestamp + conn.del::<_, ()>(&challenge_key).await.ok(); + conn.del::<_, ()>(×tamp_key).await.ok(); + } + } + } + } + + Ok(()) + } +} + diff --git a/src/acme/upstreams_reader.rs b/src/acme/upstreams_reader.rs index 92619c0..aec1c98 100644 --- a/src/acme/upstreams_reader.rs +++ b/src/acme/upstreams_reader.rs @@ -1,140 +1,140 @@ -//! 
Domain reader that reads domains from upstreams.yaml configuration - -use anyhow::{Context, Result}; -use crate::acme::domain_reader::{DomainConfig, DomainReader}; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::sync::RwLock; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpstreamsAcmeConfig { - /// Challenge type: "http-01" or "dns-01" - #[serde(default = "default_challenge_type")] - pub challenge_type: String, - /// Email for ACME account (optional, can be set globally) - pub email: Option, - /// Whether this is a wildcard domain - #[serde(default)] - pub wildcard: bool, -} - -fn default_challenge_type() -> String { - "http-01".to_string() -} - -/// Domain reader that reads from upstreams.yaml -pub struct UpstreamsDomainReader { - upstreams_path: PathBuf, - cached_domains: Arc>>>, - /// Global email for ACME (from config) - global_email: Option, -} - -impl UpstreamsDomainReader { - pub fn new(upstreams_path: impl Into, global_email: Option) -> Self { - Self { - upstreams_path: upstreams_path.into(), - cached_domains: Arc::new(RwLock::new(None)), - global_email, - } - } - - async fn fetch_domains(&self) -> Result> { - use serde_yaml; - - let mut domains = Vec::new(); - - // Read and parse upstreams.yaml directly to get ACME config - let yaml_content = tokio::fs::read_to_string(&self.upstreams_path).await - .with_context(|| format!("Failed to read upstreams file: {:?}", self.upstreams_path))?; - - let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content) - .with_context(|| format!("Failed to parse upstreams YAML: {:?}", self.upstreams_path))?; - - if let Some(upstreams) = &parsed.upstreams { - for (hostname, host_config) in upstreams { - // Only include domains that need certificates - if !host_config.needs_certificate() { - continue; - } - - let is_wildcard = hostname.starts_with("*."); - let acme_wildcard = host_config.acme.as_ref().map(|a| a.wildcard).unwrap_or(false); - - // Determine challenge type from ACME config or auto-detect - let challenge_type = if let Some(acme_config) = &host_config.acme { - acme_config.challenge_type.clone() - } else if is_wildcard { - "dns-01".to_string() - } else { - "http-01".to_string() - }; - - // Determine email from ACME config or use global - let email = if let Some(acme_config) = &host_config.acme { - acme_config.email.clone().or_else(|| self.global_email.clone()) - } else { - self.global_email.clone() - }; - - // Determine the domain to use for ACME request - // If wildcard is true and certificate is specified, use *.{certificate} - // Otherwise, use the hostname as-is - let acme_domain = if acme_wildcard && !is_wildcard { - // Wildcard is set in config but hostname doesn't start with *. 
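One behavior of the base-domain fallback in this branch is worth pinning down: keeping only the last two labels collapses every deeper hostname to its registrable suffix, so dev01.sub.example.com derives *.example.com, not *.sub.example.com (the inline "-> sub.example.com" comment does not match what the code computes). The same logic, extracted as a standalone sketch:

    // Same last-two-labels rule as the branch above, isolated for clarity.
    fn derive_wildcard_domain(hostname: &str) -> String {
        let parts: Vec<&str> = hostname.split('.').collect();
        if parts.len() >= 3 {
            // Keep only the final two labels.
            format!("*.{}.{}", parts[parts.len() - 2], parts[parts.len() - 1])
        } else {
            format!("*.{}", hostname)
        }
    }

    // derive_wildcard_domain("dev01.sub.example.com") == "*.example.com"
    // derive_wildcard_domain("example.com")           == "*.example.com"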
- // Use certificate domain if available, otherwise use hostname - if let Some(cert_domain) = &host_config.certificate { - format!("*.{}", cert_domain) - } else { - // Extract base domain from hostname (remove subdomain) - // For dev01.sub.example.com -> sub.example.com - let parts: Vec<&str> = hostname.split('.').collect(); - if parts.len() >= 3 { - // Take last 2 parts as base domain - format!("*.{}.{}", parts[parts.len() - 2], parts[parts.len() - 1]) - } else { - format!("*.{}", hostname) - } - } - } else { - hostname.clone() - }; - - domains.push(DomainConfig { - domain: acme_domain, - email, - dns: challenge_type == "dns-01", - wildcard: is_wildcard || acme_wildcard, - }); - } - } - - Ok(domains) - } -} - -#[async_trait::async_trait] -impl DomainReader for UpstreamsDomainReader { - async fn read_domains(&self) -> Result> { - // Try cache first - { - let cache = self.cached_domains.read().await; - if let Some(domains) = cache.as_ref() { - return Ok(domains.clone()); - } - } - - // Fetch fresh data - let domains = self.fetch_domains().await?; - - // Update cache - { - let mut cache = self.cached_domains.write().await; - *cache = Some(domains.clone()); - } - - Ok(domains) - } -} - +//! Domain reader that reads domains from upstreams.yaml configuration + +use anyhow::{Context, Result}; +use crate::acme::domain_reader::{DomainConfig, DomainReader}; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::RwLock; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamsAcmeConfig { + /// Challenge type: "http-01" or "dns-01" + #[serde(default = "default_challenge_type")] + pub challenge_type: String, + /// Email for ACME account (optional, can be set globally) + pub email: Option, + /// Whether this is a wildcard domain + #[serde(default)] + pub wildcard: bool, +} + +fn default_challenge_type() -> String { + "http-01".to_string() +} + +/// Domain reader that reads from upstreams.yaml +pub struct UpstreamsDomainReader { + upstreams_path: PathBuf, + cached_domains: Arc>>>, + /// Global email for ACME (from config) + global_email: Option, +} + +impl UpstreamsDomainReader { + pub fn new(upstreams_path: impl Into, global_email: Option) -> Self { + Self { + upstreams_path: upstreams_path.into(), + cached_domains: Arc::new(RwLock::new(None)), + global_email, + } + } + + async fn fetch_domains(&self) -> Result> { + use serde_yaml; + + let mut domains = Vec::new(); + + // Read and parse upstreams.yaml directly to get ACME config + let yaml_content = tokio::fs::read_to_string(&self.upstreams_path).await + .with_context(|| format!("Failed to read upstreams file: {:?}", self.upstreams_path))?; + + let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content) + .with_context(|| format!("Failed to parse upstreams YAML: {:?}", self.upstreams_path))?; + + if let Some(upstreams) = &parsed.upstreams { + for (hostname, host_config) in upstreams { + // Only include domains that need certificates + if !host_config.needs_certificate() { + continue; + } + + let is_wildcard = hostname.starts_with("*."); + let acme_wildcard = host_config.acme.as_ref().map(|a| a.wildcard).unwrap_or(false); + + // Determine challenge type from ACME config or auto-detect + let challenge_type = if let Some(acme_config) = &host_config.acme { + acme_config.challenge_type.clone() + } else if is_wildcard { + "dns-01".to_string() + } else { + "http-01".to_string() + }; + + // Determine email from ACME config or use global + let email = if let Some(acme_config) = 
&host_config.acme { + acme_config.email.clone().or_else(|| self.global_email.clone()) + } else { + self.global_email.clone() + }; + + // Determine the domain to use for ACME request + // If wildcard is true and certificate is specified, use *.{certificate} + // Otherwise, use the hostname as-is + let acme_domain = if acme_wildcard && !is_wildcard { + // Wildcard is set in config but hostname doesn't start with *. + // Use certificate domain if available, otherwise use hostname + if let Some(cert_domain) = &host_config.certificate { + format!("*.{}", cert_domain) + } else { + // Extract base domain from hostname (remove subdomain) + // For dev01.sub.example.com -> sub.example.com + let parts: Vec<&str> = hostname.split('.').collect(); + if parts.len() >= 3 { + // Take last 2 parts as base domain + format!("*.{}.{}", parts[parts.len() - 2], parts[parts.len() - 1]) + } else { + format!("*.{}", hostname) + } + } + } else { + hostname.clone() + }; + + domains.push(DomainConfig { + domain: acme_domain, + email, + dns: challenge_type == "dns-01", + wildcard: is_wildcard || acme_wildcard, + }); + } + } + + Ok(domains) + } +} + +#[async_trait::async_trait] +impl DomainReader for UpstreamsDomainReader { + async fn read_domains(&self) -> Result> { + // Try cache first + { + let cache = self.cached_domains.read().await; + if let Some(domains) = cache.as_ref() { + return Ok(domains.clone()); + } + } + + // Fetch fresh data + let domains = self.fetch_domains().await?; + + // Update cache + { + let mut cache = self.cached_domains.write().await; + *cache = Some(domains.clone()); + } + + Ok(domains) + } +} + diff --git a/src/actions/captcha.rs b/src/actions/captcha.rs index 74644a8..92a3232 100644 --- a/src/actions/captcha.rs +++ b/src/actions/captcha.rs @@ -1,1011 +1,1011 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use chrono::Utc; -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, OnceCell}; -use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; -use uuid::Uuid; - -use crate::redis::RedisManager; -use crate::http_client::get_global_reqwest_client; - -/// Captcha provider types supported by Gen0Sec -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] -pub enum CaptchaProvider { - #[serde(rename = "hcaptcha")] - HCaptcha, - #[serde(rename = "recaptcha")] - ReCaptcha, - #[serde(rename = "turnstile")] - Turnstile, -} - -impl Default for CaptchaProvider { - fn default() -> Self { - CaptchaProvider::HCaptcha - } -} - -impl std::str::FromStr for CaptchaProvider { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "hcaptcha" => Ok(CaptchaProvider::HCaptcha), - "recaptcha" => Ok(CaptchaProvider::ReCaptcha), - "turnstile" => Ok(CaptchaProvider::Turnstile), - _ => Err(anyhow::anyhow!("Invalid captcha provider: {}", s)), - } - } -} - -/// Captcha validation request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaValidationRequest { - pub response_token: String, - pub ip_address: String, - pub user_agent: Option, - pub site_key: String, - pub secret_key: String, - pub provider: CaptchaProvider, -} - -/// Captcha validation response -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaValidationResponse { - pub success: bool, - pub error_codes: Option>, - pub challenge_ts: Option, - pub hostname: Option, - pub score: Option, - pub action: Option, 
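A hedged caveat on this response struct, based on the public siteverify APIs rather than on anything in this patch: hCaptcha, reCAPTCHA, and Turnstile all return the error list under the hyphenated JSON key "error-codes", so a plain error_codes field deserializes to None unless it carries a serde rename. A minimal sketch of a shape that would capture it:

    // Assumed shape; the rename attribute is the point here.
    #[derive(Debug, serde::Deserialize)]
    struct SiteverifyErrors {
        success: bool,
        // Provider JSON uses "error-codes", not "error_codes".
        #[serde(rename = "error-codes")]
        error_codes: Option<Vec<String>>,
    }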
-} - -/// JWT Claims for captcha tokens -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaClaims { - /// Standard JWT claims - pub sub: String, // Subject (user identifier) - pub iss: String, // Issuer - pub aud: String, // Audience - pub exp: i64, // Expiration time - pub iat: i64, // Issued at - pub jti: String, // JWT ID (unique identifier) - - /// Custom captcha claims - pub ip_address: String, - pub user_agent: String, - pub ja4_fingerprint: Option, - pub captcha_provider: String, - pub captcha_validated: bool, -} - -/// Captcha token with JWT-based security -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaToken { - pub token: String, - pub claims: CaptchaClaims, -} - -/// Cached captcha validation result -#[derive(Debug, Clone)] -pub struct CachedCaptchaResult { - pub is_valid: bool, - pub expires_at: Instant, -} - -/// Captcha action configuration -#[derive(Debug, Clone)] -pub struct CaptchaConfig { - pub site_key: String, - pub secret_key: String, - pub jwt_secret: String, - pub provider: CaptchaProvider, - pub token_ttl_seconds: u64, - pub validation_cache_ttl_seconds: u64, -} - -/// Captcha client for validation and token management -pub struct CaptchaClient { - config: CaptchaConfig, - validation_cache: Arc>>, - validated_tokens: Arc>>, // JTI -> expiration time -} - -impl CaptchaClient { - pub fn new( - config: CaptchaConfig, - ) -> Self { - Self { - config, - validation_cache: Arc::new(RwLock::new(HashMap::new())), - validated_tokens: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Validate a captcha response token - pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { - log::info!("Starting captcha validation for IP: {}, provider: {:?}", - request.ip_address, self.config.provider); - - // Check if captcha response is provided - if request.response_token.is_empty() { - log::warn!("No captcha response provided for IP: {}", request.ip_address); - return Ok(false); - } - - log::debug!("Captcha response token length: {}", request.response_token.len()); - - // Check validation cache first - let cache_key = format!("{}:{}", request.response_token, request.ip_address); - if let Some(cached) = self.get_validation_cache(&cache_key).await { - if cached.expires_at > Instant::now() { - log::debug!("Captcha validation for {} found in cache", request.ip_address); - return Ok(cached.is_valid); - } else { - self.remove_validation_cache(&cache_key).await; - } - } - - // Validate with provider API - let is_valid = match self.config.provider { - CaptchaProvider::HCaptcha => self.validate_hcaptcha(&request).await?, - CaptchaProvider::ReCaptcha => self.validate_recaptcha(&request).await?, - CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, - }; - - log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); - - // Cache the result - self.set_validation_cache(&cache_key, is_valid).await; - - Ok(is_valid) - } - - /// Generate a secure JWT captcha token - pub async fn generate_token( - &self, - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, - ) -> Result { - let now = Utc::now(); - let exp = now + chrono::Duration::seconds(self.config.token_ttl_seconds as i64); - let jti = Uuid::new_v4().to_string(); - - let claims = CaptchaClaims { - sub: format!("captcha:{}", ip_address), - iss: "arxignis-synapse".to_string(), - aud: "captcha-validation".to_string(), - exp: exp.timestamp(), - iat: now.timestamp(), - jti: jti.clone(), - ip_address: ip_address.clone(), - 
user_agent: user_agent.clone(), - ja4_fingerprint, - captcha_provider: format!("{:?}", self.config.provider), - captcha_validated: false, - }; - - let header = Header::new(Algorithm::HS256); - let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); - - let token = encode(&header, &claims, &encoding_key) - .context("Failed to encode JWT token")?; - - let captcha_token = CaptchaToken { - token: token.clone(), - claims: claims.clone(), - }; - - // Store token in Redis for validation (optional, JWT is self-contained) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); - let mut redis = redis_manager.get_connection(); - let token_data = serde_json::to_string(&captcha_token) - .context("Failed to serialize captcha token")?; - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to store captcha token in Redis")?; - } - - Ok(captcha_token) - } - - /// Validate a JWT captcha token - pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Check if token is expired (JWT handles this automatically, but double-check) - let now = Utc::now().timestamp(); - if claims.exp < now { - log::debug!("JWT token expired"); - return Ok(false); - } - - // Verify IP and User-Agent binding - if claims.ip_address != ip_address || claims.user_agent != user_agent { - log::warn!("JWT token validation failed: IP or User-Agent mismatch"); - return Ok(false); - } - - // Check Redis first for updated token state - let mut captcha_validated = claims.captcha_validated; - log::debug!("Initial JWT token captcha_validated: {}", captcha_validated); - - // Check in-memory cache first (faster) - { - let validated_tokens = self.validated_tokens.read().await; - if let Some(expiration) = validated_tokens.get(&claims.jti) { - if *expiration > Instant::now() { - captcha_validated = true; - log::debug!("Found validated token JTI {} in memory cache", claims.jti); - } else { - log::debug!("Token JTI {} expired in memory cache", claims.jti); - } - } - } - - // If not found in memory cache, check Redis - if !captcha_validated { - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Looking up token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - match redis.get::<_, String>(&key).await { - Ok(token_data_str) => { - log::debug!("Found token data in Redis: {}", token_data_str); - if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { - captcha_validated = updated_token.claims.captcha_validated; - log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); - - // Update memory cache if found in Redis - if captcha_validated { - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - } - } else { - log::warn!("Failed to parse token data from Redis"); - } - } - Err(e) => { - log::debug!("Redis token lookup failed 
for JTI {}: {}", claims.jti, e); - } - } - } else { - log::debug!("Redis manager not available"); - } - } - - // Check if captcha was validated (either from JWT or Redis) - if !captcha_validated { - log::debug!("JWT token not validated for captcha"); - return Ok(false); - } - - // Optional: Check Redis blacklist for revoked tokens - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - match redis.exists::<_, bool>(&blacklist_key).await { - Ok(true) => { - log::debug!("JWT token {} is blacklisted", claims.jti); - return Ok(false); - } - Ok(false) => { - // Token not blacklisted, continue validation - } - Err(e) => { - log::warn!("Redis blacklist check error for JWT {}: {}", claims.jti, e); - // Continue validation despite Redis error - } - } - } - - Ok(true) - } - Err(e) => { - log::warn!("JWT token validation failed: {}", e); - Ok(false) - } - } - } - - /// Mark a JWT token as validated after successful captcha completion - pub async fn mark_token_validated(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Store the JTI as validated in memory cache - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - { - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); - } - - // Also update Redis cache if available (for persistence across restarts) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Storing updated token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - let mut updated_claims = claims.clone(); - updated_claims.captcha_validated = true; - - let updated_captcha_token = CaptchaToken { - token: token.to_string(), - claims: updated_claims, - }; - let token_data = serde_json::to_string(&updated_captcha_token) - .context("Failed to serialize updated captcha token")?; - - log::debug!("Token data to store: {}", token_data); - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to update captcha token in Redis")?; - - log::debug!("Successfully stored updated token in Redis"); - } else { - log::debug!("Redis manager not available for token storage"); - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for validation marking: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Revoke a JWT token by adding it to blacklist - pub async fn revoke_token(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Add to Redis blacklist - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = 
format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - let _: () = redis - .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) - .await - .context("Failed to add token to blacklist")?; - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for revocation: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Apply captcha challenge (return HTML form) - pub fn apply_captcha_challenge(&self, site_key: &str) -> String { - self.render_captcha_template(site_key, None) - } - - /// Apply captcha challenge with JWT token (return HTML form) - pub fn apply_captcha_challenge_with_token(&self, site_key: &str, jwt_token: &str) -> String { - self.render_captcha_template(site_key, Some(jwt_token)) - } - - /// Render captcha template based on provider - fn render_captcha_template(&self, site_key: &str, jwt_token: Option<&str>) -> String { - let (frontend_js, frontend_key, callback_attr) = match self.config.provider { - CaptchaProvider::HCaptcha => ( - "https://js.hcaptcha.com/1/api.js", - "h-captcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::ReCaptcha => ( - "https://www.recaptcha.net/recaptcha/api.js", - "g-recaptcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::Turnstile => ( - "https://challenges.cloudflare.com/turnstile/v0/api.js", - "cf-turnstile", - "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"" - ), - }; - - let jwt_token_input = if let Some(token) = jwt_token { - format!(r#""#, token) - } else { - r#""#.to_string() - }; - - let html_template = format!( - r#" - - - Gen0Sec Captcha - - - - - - -
-            <!-- page markup lost in extraction: "Gen0Sec Captcha" heading,
-                 the prompt "Please complete the security verification below
-                 to continue.", the provider widget container, and the hidden
-                 JWT token field -->
- - - -"#, - frontend_js, - frontend_key, - site_key, - callback_attr, - jwt_token_input - ); - html_template - } - - /// Validate with hCaptcha API - async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("sitekey", &request.site_key); - params.insert("remoteip", &request.ip_address); - - log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://hcaptcha.com/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send hCaptcha validation request")?; - - log::info!("hCaptcha validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("hCaptcha service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse hCaptcha response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("hCaptcha secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid hCaptcha response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("hCaptcha response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("hCaptcha validation failed with error code: {}", error_code); - } - } - } - } - log::info!("hCaptcha validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with reCAPTCHA API - async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://www.recaptcha.net/recaptcha/api/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send reCAPTCHA validation request")?; - - log::info!("reCAPTCHA validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("reCAPTCHA service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse reCAPTCHA response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("reCAPTCHA secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid reCAPTCHA response from user"); - return Ok(false); - } - 
"timeout-or-duplicate" => { - log::info!("reCAPTCHA response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("reCAPTCHA validation failed with error code: {}", error_code); - } - } - } - } - log::info!("reCAPTCHA validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with Cloudflare Turnstile API - async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("Turnstile validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://challenges.cloudflare.com/turnstile/v0/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send Turnstile validation request")?; - - log::info!("Turnstile validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("Turnstile service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse Turnstile response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("Turnstile secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid Turnstile response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("Turnstile response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("Turnstile validation failed with error code: {}", error_code); - } - } - } - } - log::info!("Turnstile validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Get the captcha backend response key name for the current provider - pub fn get_captcha_backend_key(&self) -> &'static str { - match self.config.provider { - CaptchaProvider::HCaptcha => "h-captcha-response", - CaptchaProvider::ReCaptcha => "g-recaptcha-response", - CaptchaProvider::Turnstile => "cf-turnstile-response", - } - } - - /// Get validation result from cache - async fn get_validation_cache(&self, key: &str) -> Option { - let cache = self.validation_cache.read().await; - cache.get(key).cloned() - } - - /// Set validation result in cache - async fn set_validation_cache(&self, key: &str, is_valid: bool) { - let mut cache = self.validation_cache.write().await; - cache.insert( - key.to_string(), - CachedCaptchaResult { - is_valid, - expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds), - }, - ); - } - - /// Remove validation result from cache - async fn remove_validation_cache(&self, key: &str) { - let mut cache = self.validation_cache.write().await; - cache.remove(key); - } - - /// Clean up expired cache entries - pub async fn cleanup_cache(&self) { - let mut cache = self.validation_cache.write().await; - let now = Instant::now(); - cache.retain(|_, cached| cached.expires_at > now); - - // Also clean up expired validated tokens - let mut validated_tokens = 
self.validated_tokens.write().await; - validated_tokens.retain(|_, expiration| *expiration > now); - } -} - -/// Global captcha client instance -static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); - -/// Initialize the global captcha client -pub async fn init_captcha_client( - config: CaptchaConfig, -) -> Result<()> { - let client = Arc::new(CaptchaClient::new(config)); - - CAPTCHA_CLIENT.set(client) - .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; - - Ok(()) -} - -/// Validate captcha response -pub async fn validate_captcha_response( - response_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - let request = CaptchaValidationRequest { - response_token, - ip_address, - user_agent: user_agent, - site_key: client.config.site_key.clone(), - secret_key: client.config.secret_key.clone(), - provider: client.config.provider.clone(), - }; - - client.validate_captcha(request).await -} - -/// Generate captcha token -pub async fn generate_captcha_token( - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.generate_token(ip_address, user_agent, ja4_fingerprint).await -} - -/// Validate captcha token -pub async fn validate_captcha_token( - token: &str, - ip_address: &str, - user_agent: &str, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.validate_token(token, ip_address, user_agent).await -} - -/// Apply captcha challenge -pub fn apply_captcha_challenge() -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge(&client.config.site_key)) -} - -/// Apply captcha challenge with JWT token -pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token)) -} - -/// Get the captcha backend response key name -pub fn get_captcha_backend_key() -> Result<&'static str> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.get_captcha_backend_key()) -} - -/// Mark a JWT token as validated after successful captcha completion -pub async fn mark_captcha_token_validated(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.mark_token_validated(token).await -} - -/// Revoke a JWT token -pub async fn revoke_captcha_token(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.revoke_token(token).await -} - -/// Validate captcha response and mark token as validated -pub async fn validate_and_mark_captcha( - response_token: String, - jwt_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", - ip_address, response_token.len(), jwt_token.len()); - - // First validate the captcha response - let is_valid = validate_captcha_response(response_token, 
ip_address.clone(), user_agent.clone()).await?; - - log::info!("Captcha validation result: {}", is_valid); - - if is_valid { - // Only try to mark JWT token as validated if it's not empty - if !jwt_token.is_empty() { - if let Err(e) = mark_captcha_token_validated(&jwt_token).await { - log::warn!("Failed to mark JWT token as validated: {}", e); - // Don't return false here - captcha validation succeeded - } else { - log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address); - } - } else { - log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address); - } - } else { - log::warn!("Captcha validation failed for IP: {}", ip_address); - } - - Ok(is_valid) -} - -/// Start periodic cache cleanup task -pub async fn start_cache_cleanup_task() { - tokio::spawn(async { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - loop { - interval.tick().await; - if let Some(client) = CAPTCHA_CLIENT.get() { - client.cleanup_cache().await; - } - } - }); -} +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::{Context, Result}; +use chrono::Utc; +use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; +use tokio::sync::{RwLock, OnceCell}; +use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; +use uuid::Uuid; + +use crate::redis::RedisManager; +use crate::http_client::get_global_reqwest_client; + +/// Captcha provider types supported by Gen0Sec +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] +pub enum CaptchaProvider { + #[serde(rename = "hcaptcha")] + HCaptcha, + #[serde(rename = "recaptcha")] + ReCaptcha, + #[serde(rename = "turnstile")] + Turnstile, +} + +impl Default for CaptchaProvider { + fn default() -> Self { + CaptchaProvider::HCaptcha + } +} + +impl std::str::FromStr for CaptchaProvider { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "hcaptcha" => Ok(CaptchaProvider::HCaptcha), + "recaptcha" => Ok(CaptchaProvider::ReCaptcha), + "turnstile" => Ok(CaptchaProvider::Turnstile), + _ => Err(anyhow::anyhow!("Invalid captcha provider: {}", s)), + } + } +} + +/// Captcha validation request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaValidationRequest { + pub response_token: String, + pub ip_address: String, + pub user_agent: Option, + pub site_key: String, + pub secret_key: String, + pub provider: CaptchaProvider, +} + +/// Captcha validation response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaValidationResponse { + pub success: bool, + pub error_codes: Option>, + pub challenge_ts: Option, + pub hostname: Option, + pub score: Option, + pub action: Option, +} + +/// JWT Claims for captcha tokens +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaClaims { + /// Standard JWT claims + pub sub: String, // Subject (user identifier) + pub iss: String, // Issuer + pub aud: String, // Audience + pub exp: i64, // Expiration time + pub iat: i64, // Issued at + pub jti: String, // JWT ID (unique identifier) + + /// Custom captcha claims + pub ip_address: String, + pub user_agent: String, + pub ja4_fingerprint: Option, + pub captcha_provider: String, + pub captcha_validated: bool, +} + +/// Captcha token with JWT-based security +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaToken { + pub token: String, + pub claims: CaptchaClaims, +} + +/// Cached captcha validation result 
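The HS256 claims flow defined above round-trips as follows — a self-contained sketch using the same jsonwebtoken calls and audience check; the helper name and secret handling are invented for illustration:

    // Hypothetical helper: encode the claims, then decode them back with the
    // same audience check the client uses ("captcha-validation").
    fn jwt_round_trip(secret: &[u8], claims: &CaptchaClaims) -> anyhow::Result<bool> {
        use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey};
        let token = encode(&Header::new(Algorithm::HS256), claims, &EncodingKey::from_secret(secret))?;
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_audience(&["captcha-validation"]);
        // Default validation also checks `exp`, which CaptchaClaims carries.
        let decoded = decode::<CaptchaClaims>(&token, &DecodingKey::from_secret(secret), &validation)?;
        Ok(decoded.claims.jti == claims.jti)
    }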
+#[derive(Debug, Clone)] +pub struct CachedCaptchaResult { + pub is_valid: bool, + pub expires_at: Instant, +} + +/// Captcha action configuration +#[derive(Debug, Clone)] +pub struct CaptchaConfig { + pub site_key: String, + pub secret_key: String, + pub jwt_secret: String, + pub provider: CaptchaProvider, + pub token_ttl_seconds: u64, + pub validation_cache_ttl_seconds: u64, +} + +/// Captcha client for validation and token management +pub struct CaptchaClient { + config: CaptchaConfig, + validation_cache: Arc>>, + validated_tokens: Arc>>, // JTI -> expiration time +} + +impl CaptchaClient { + pub fn new( + config: CaptchaConfig, + ) -> Self { + Self { + config, + validation_cache: Arc::new(RwLock::new(HashMap::new())), + validated_tokens: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Validate a captcha response token + pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { + log::info!("Starting captcha validation for IP: {}, provider: {:?}", + request.ip_address, self.config.provider); + + // Check if captcha response is provided + if request.response_token.is_empty() { + log::warn!("No captcha response provided for IP: {}", request.ip_address); + return Ok(false); + } + + log::debug!("Captcha response token length: {}", request.response_token.len()); + + // Check validation cache first + let cache_key = format!("{}:{}", request.response_token, request.ip_address); + if let Some(cached) = self.get_validation_cache(&cache_key).await { + if cached.expires_at > Instant::now() { + log::debug!("Captcha validation for {} found in cache", request.ip_address); + return Ok(cached.is_valid); + } else { + self.remove_validation_cache(&cache_key).await; + } + } + + // Validate with provider API + let is_valid = match self.config.provider { + CaptchaProvider::HCaptcha => self.validate_hcaptcha(&request).await?, + CaptchaProvider::ReCaptcha => self.validate_recaptcha(&request).await?, + CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, + }; + + log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); + + // Cache the result + self.set_validation_cache(&cache_key, is_valid).await; + + Ok(is_valid) + } + + /// Generate a secure JWT captcha token + pub async fn generate_token( + &self, + ip_address: String, + user_agent: String, + ja4_fingerprint: Option, + ) -> Result { + let now = Utc::now(); + let exp = now + chrono::Duration::seconds(self.config.token_ttl_seconds as i64); + let jti = Uuid::new_v4().to_string(); + + let claims = CaptchaClaims { + sub: format!("captcha:{}", ip_address), + iss: "arxignis-synapse".to_string(), + aud: "captcha-validation".to_string(), + exp: exp.timestamp(), + iat: now.timestamp(), + jti: jti.clone(), + ip_address: ip_address.clone(), + user_agent: user_agent.clone(), + ja4_fingerprint, + captcha_provider: format!("{:?}", self.config.provider), + captcha_validated: false, + }; + + let header = Header::new(Algorithm::HS256); + let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); + + let token = encode(&header, &claims, &encoding_key) + .context("Failed to encode JWT token")?; + + let captcha_token = CaptchaToken { + token: token.clone(), + claims: claims.clone(), + }; + + // Store token in Redis for validation (optional, JWT is self-contained) + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); + let mut redis = redis_manager.get_connection(); + let token_data = 
serde_json::to_string(&captcha_token) + .context("Failed to serialize captcha token")?; + + let _: () = redis + .set_ex(&key, token_data, self.config.token_ttl_seconds) + .await + .context("Failed to store captcha token in Redis")?; + } + + Ok(captcha_token) + } + + /// Validate a JWT captcha token + pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Check if token is expired (JWT handles this automatically, but double-check) + let now = Utc::now().timestamp(); + if claims.exp < now { + log::debug!("JWT token expired"); + return Ok(false); + } + + // Verify IP and User-Agent binding + if claims.ip_address != ip_address || claims.user_agent != user_agent { + log::warn!("JWT token validation failed: IP or User-Agent mismatch"); + return Ok(false); + } + + // Check Redis first for updated token state + let mut captcha_validated = claims.captcha_validated; + log::debug!("Initial JWT token captcha_validated: {}", captcha_validated); + + // Check in-memory cache first (faster) + { + let validated_tokens = self.validated_tokens.read().await; + if let Some(expiration) = validated_tokens.get(&claims.jti) { + if *expiration > Instant::now() { + captcha_validated = true; + log::debug!("Found validated token JTI {} in memory cache", claims.jti); + } else { + log::debug!("Token JTI {} expired in memory cache", claims.jti); + } + } + } + + // If not found in memory cache, check Redis + if !captcha_validated { + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + log::debug!("Looking up token in Redis with key: {}", key); + + let mut redis = redis_manager.get_connection(); + match redis.get::<_, String>(&key).await { + Ok(token_data_str) => { + log::debug!("Found token data in Redis: {}", token_data_str); + if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { + captcha_validated = updated_token.claims.captcha_validated; + log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); + + // Update memory cache if found in Redis + if captcha_validated { + let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); + let mut validated_tokens = self.validated_tokens.write().await; + validated_tokens.insert(claims.jti.clone(), expiration); + } + } else { + log::warn!("Failed to parse token data from Redis"); + } + } + Err(e) => { + log::debug!("Redis token lookup failed for JTI {}: {}", claims.jti, e); + } + } + } else { + log::debug!("Redis manager not available"); + } + } + + // Check if captcha was validated (either from JWT or Redis) + if !captcha_validated { + log::debug!("JWT token not validated for captcha"); + return Ok(false); + } + + // Optional: Check Redis blacklist for revoked tokens + if let Ok(redis_manager) = RedisManager::get() { + let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let mut redis = redis_manager.get_connection(); + match redis.exists::<_, bool>(&blacklist_key).await { + Ok(true) => { + log::debug!("JWT token {} is blacklisted", claims.jti); + return Ok(false); + } + Ok(false) => { + // Token not blacklisted, continue 
validation + } + Err(e) => { + log::warn!("Redis blacklist check error for JWT {}: {}", claims.jti, e); + // Continue validation despite Redis error + } + } + } + + Ok(true) + } + Err(e) => { + log::warn!("JWT token validation failed: {}", e); + Ok(false) + } + } + } + + /// Mark a JWT token as validated after successful captcha completion + pub async fn mark_token_validated(&self, token: &str) -> Result<()> { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Store the JTI as validated in memory cache + let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); + { + let mut validated_tokens = self.validated_tokens.write().await; + validated_tokens.insert(claims.jti.clone(), expiration); + log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); + } + + // Also update Redis cache if available (for persistence across restarts) + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + log::debug!("Storing updated token in Redis with key: {}", key); + + let mut redis = redis_manager.get_connection(); + let mut updated_claims = claims.clone(); + updated_claims.captcha_validated = true; + + let updated_captcha_token = CaptchaToken { + token: token.to_string(), + claims: updated_claims, + }; + let token_data = serde_json::to_string(&updated_captcha_token) + .context("Failed to serialize updated captcha token")?; + + log::debug!("Token data to store: {}", token_data); + + let _: () = redis + .set_ex(&key, token_data, self.config.token_ttl_seconds) + .await + .context("Failed to update captcha token in Redis")?; + + log::debug!("Successfully stored updated token in Redis"); + } else { + log::debug!("Redis manager not available for token storage"); + } + + Ok(()) + } + Err(e) => { + log::warn!("Failed to decode JWT token for validation marking: {}", e); + Err(anyhow::anyhow!("Invalid JWT token: {}", e)) + } + } + } + + /// Revoke a JWT token by adding it to blacklist + pub async fn revoke_token(&self, token: &str) -> Result<()> { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Add to Redis blacklist + if let Ok(redis_manager) = RedisManager::get() { + let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let mut redis = redis_manager.get_connection(); + let _: () = redis + .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) + .await + .context("Failed to add token to blacklist")?; + } + + Ok(()) + } + Err(e) => { + log::warn!("Failed to decode JWT token for revocation: {}", e); + Err(anyhow::anyhow!("Invalid JWT token: {}", e)) + } + } + } + + /// Apply captcha challenge (return HTML form) + pub fn apply_captcha_challenge(&self, site_key: &str) -> String { + self.render_captcha_template(site_key, None) + } + + /// Apply captcha challenge with JWT token (return HTML form) + pub fn apply_captcha_challenge_with_token(&self, site_key: &str, jwt_token: &str) -> String 
{ + self.render_captcha_template(site_key, Some(jwt_token)) + } + + /// Render captcha template based on provider + fn render_captcha_template(&self, site_key: &str, jwt_token: Option<&str>) -> String { + let (frontend_js, frontend_key, callback_attr) = match self.config.provider { + CaptchaProvider::HCaptcha => ( + "https://js.hcaptcha.com/1/api.js", + "h-captcha", + "data-callback=\"captchaCallback\"" + ), + CaptchaProvider::ReCaptcha => ( + "https://www.recaptcha.net/recaptcha/api.js", + "g-recaptcha", + "data-callback=\"captchaCallback\"" + ), + CaptchaProvider::Turnstile => ( + "https://challenges.cloudflare.com/turnstile/v0/api.js", + "cf-turnstile", + "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"" + ), + }; + + let jwt_token_input = if let Some(token) = jwt_token { + format!(r#""#, token) + } else { + r#""#.to_string() + }; + + let html_template = format!( + r#" + + + Gen0Sec Captcha + + + + + + +
+            <!-- page markup lost in extraction: "Gen0Sec Captcha" heading,
+                 the prompt "Please complete the security verification below
+                 to continue.", the provider widget container, and the hidden
+                 JWT token field -->
+
+
+
+"#,
+            frontend_js,
+            frontend_key,
+            site_key,
+            callback_attr,
+            jwt_token_input
+        );
+        html_template
+    }
+
+    /// Validate with hCaptcha API
+    async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result<bool> {
+        // Use shared HTTP client with keepalive instead of creating new client
+        let client = get_global_reqwest_client()
+            .context("Failed to get global HTTP client")?;
+
+        let mut params = HashMap::new();
+        params.insert("response", &request.response_token);
+        params.insert("secret", &request.secret_key);
+        params.insert("sitekey", &request.site_key);
+        params.insert("remoteip", &request.ip_address);
+
+        log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}",
+            request.response_token.len(), request.ip_address);
+
+        let response = client
+            .post("https://hcaptcha.com/siteverify")
+            .form(&params)
+            .send()
+            .await
+            .context("Failed to send hCaptcha validation request")?;
+
+        log::info!("hCaptcha validation HTTP response - status: {}", response.status());
+
+        if !response.status().is_success() {
+            log::error!("hCaptcha service returned non-success status: {}", response.status());
+            return Ok(false);
+        }
+
+        let validation_response: CaptchaValidationResponse = response
+            .json()
+            .await
+            .context("Failed to parse hCaptcha response")?;
+
+        if !validation_response.success {
+            if let Some(error_codes) = &validation_response.error_codes {
+                for error_code in error_codes {
+                    match error_code.as_str() {
+                        "invalid-input-secret" => {
+                            log::error!("hCaptcha secret key is invalid");
+                            return Ok(false);
+                        }
+                        "invalid-input-response" => {
+                            log::info!("Invalid hCaptcha response from user");
+                            return Ok(false);
+                        }
+                        "timeout-or-duplicate" => {
+                            log::info!("hCaptcha response expired or duplicate");
+                            return Ok(false);
+                        }
+                        _ => {
+                            log::warn!("hCaptcha validation failed with error code: {}", error_code);
+                        }
+                    }
+                }
+            }
+            log::info!("hCaptcha validation failed without specific error code");
+            return Ok(false);
+        }
+
+        Ok(true)
+    }
+
+    /// Validate with reCAPTCHA API
+    async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result<bool> {
+        // Use shared HTTP client with keepalive instead of creating new client
+        let client = get_global_reqwest_client()
+            .context("Failed to get global HTTP client")?;
+
+        let mut params = HashMap::new();
+        params.insert("response", &request.response_token);
+        params.insert("secret", &request.secret_key);
+        params.insert("remoteip", &request.ip_address);
+
+        log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}",
+            request.response_token.len(), request.ip_address);
+
+        let response = client
+            .post("https://www.recaptcha.net/recaptcha/api/siteverify")
+            .form(&params)
+            .send()
+            .await
+            .context("Failed to send reCAPTCHA validation request")?;
+
+        log::info!("reCAPTCHA validation HTTP response - status: {}", response.status());
+
+        if !response.status().is_success() {
+            log::error!("reCAPTCHA service returned non-success status: {}", response.status());
+            return Ok(false);
+        }
+
+        let validation_response: CaptchaValidationResponse = response
+            .json()
+            .await
+            .context("Failed to parse reCAPTCHA response")?;
+
+        if !validation_response.success {
+            if let Some(error_codes) = &validation_response.error_codes {
+                for error_code in error_codes {
+                    match error_code.as_str() {
+                        "invalid-input-secret" => {
+                            log::error!("reCAPTCHA secret key is invalid");
+                            return Ok(false);
+                        }
+                        "invalid-input-response" => {
+                            log::info!("Invalid reCAPTCHA response from user");
+                            return Ok(false);
+                        }
+                        "timeout-or-duplicate" => {
+                            log::info!("reCAPTCHA response expired or duplicate");
+                            return Ok(false);
+                        }
+                        _ => {
+                            log::warn!("reCAPTCHA validation failed with error code: {}", error_code);
+                        }
+                    }
+                }
+            }
+            log::info!("reCAPTCHA validation failed without specific error code");
+            return Ok(false);
+        }
+
+        Ok(true)
+    }
+
+    /// Validate with Cloudflare Turnstile API
+    async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result<bool> {
+        // Use shared HTTP client with keepalive instead of creating new client
+        let client = get_global_reqwest_client()
+            .context("Failed to get global HTTP client")?;
+
+        let mut params = HashMap::new();
+        params.insert("response", &request.response_token);
+        params.insert("secret", &request.secret_key);
+        params.insert("remoteip", &request.ip_address);
+
+        log::info!("Turnstile validation request - response_length: {}, remote_ip: {}",
+            request.response_token.len(), request.ip_address);
+
+        let response = client
+            .post("https://challenges.cloudflare.com/turnstile/v0/siteverify")
+            .form(&params)
+            .send()
+            .await
+            .context("Failed to send Turnstile validation request")?;
+
+        log::info!("Turnstile validation HTTP response - status: {}", response.status());
+
+        if !response.status().is_success() {
+            log::error!("Turnstile service returned non-success status: {}", response.status());
+            return Ok(false);
+        }
+
+        let validation_response: CaptchaValidationResponse = response
+            .json()
+            .await
+            .context("Failed to parse Turnstile response")?;
+
+        if !validation_response.success {
+            if let Some(error_codes) = &validation_response.error_codes {
+                for error_code in error_codes {
+                    match error_code.as_str() {
+                        "invalid-input-secret" => {
+                            log::error!("Turnstile secret key is invalid");
+                            return Ok(false);
+                        }
+                        "invalid-input-response" => {
+                            log::info!("Invalid Turnstile response from user");
+                            return Ok(false);
+                        }
+                        "timeout-or-duplicate" => {
+                            log::info!("Turnstile response expired or duplicate");
+                            return Ok(false);
+                        }
+                        _ => {
+                            log::warn!("Turnstile validation failed with error code: {}", error_code);
+                        }
+                    }
+                }
+            }
+            log::info!("Turnstile validation failed without specific error code");
+            return Ok(false);
+        }
+
+        Ok(true)
+    }
+
+    /// Get the captcha backend response key name for the current provider
+    pub fn get_captcha_backend_key(&self) -> &'static str {
+        match self.config.provider {
+            CaptchaProvider::HCaptcha => "h-captcha-response",
+            CaptchaProvider::ReCaptcha => "g-recaptcha-response",
+            CaptchaProvider::Turnstile => "cf-turnstile-response",
+        }
+    }
+
+    /// Get validation result from cache
+    async fn get_validation_cache(&self, key: &str) -> Option<CachedCaptchaResult> {
+        let cache = self.validation_cache.read().await;
+        cache.get(key).cloned()
+    }
+
+    /// Set validation result in cache
+    async fn set_validation_cache(&self, key: &str, is_valid: bool) {
+        let mut cache = self.validation_cache.write().await;
+        cache.insert(
+            key.to_string(),
+            CachedCaptchaResult {
+                is_valid,
+                expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds),
+            },
+        );
+    }
+
+    /// Remove validation result from cache
+    async fn remove_validation_cache(&self, key: &str) {
+        let mut cache = self.validation_cache.write().await;
+        cache.remove(key);
+    }
+
+    /// Clean up expired cache entries
+    pub async fn cleanup_cache(&self) {
+        let mut cache = self.validation_cache.write().await;
+        let now = Instant::now();
+        cache.retain(|_, cached| cached.expires_at > now);
+
+        // Also clean up expired validated tokens
+        let mut validated_tokens =
self.validated_tokens.write().await; + validated_tokens.retain(|_, expiration| *expiration > now); + } +} + +/// Global captcha client instance +static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); + +/// Initialize the global captcha client +pub async fn init_captcha_client( + config: CaptchaConfig, +) -> Result<()> { + let client = Arc::new(CaptchaClient::new(config)); + + CAPTCHA_CLIENT.set(client) + .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; + + Ok(()) +} + +/// Validate captcha response +pub async fn validate_captcha_response( + response_token: String, + ip_address: String, + user_agent: Option, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + let request = CaptchaValidationRequest { + response_token, + ip_address, + user_agent: user_agent, + site_key: client.config.site_key.clone(), + secret_key: client.config.secret_key.clone(), + provider: client.config.provider.clone(), + }; + + client.validate_captcha(request).await +} + +/// Generate captcha token +pub async fn generate_captcha_token( + ip_address: String, + user_agent: String, + ja4_fingerprint: Option, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.generate_token(ip_address, user_agent, ja4_fingerprint).await +} + +/// Validate captcha token +pub async fn validate_captcha_token( + token: &str, + ip_address: &str, + user_agent: &str, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.validate_token(token, ip_address, user_agent).await +} + +/// Apply captcha challenge +pub fn apply_captcha_challenge() -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.apply_captcha_challenge(&client.config.site_key)) +} + +/// Apply captcha challenge with JWT token +pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token)) +} + +/// Get the captcha backend response key name +pub fn get_captcha_backend_key() -> Result<&'static str> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.get_captcha_backend_key()) +} + +/// Mark a JWT token as validated after successful captcha completion +pub async fn mark_captcha_token_validated(token: &str) -> Result<()> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.mark_token_validated(token).await +} + +/// Revoke a JWT token +pub async fn revoke_captcha_token(token: &str) -> Result<()> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.revoke_token(token).await +} + +/// Validate captcha response and mark token as validated +pub async fn validate_and_mark_captcha( + response_token: String, + jwt_token: String, + ip_address: String, + user_agent: Option, +) -> Result { + log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", + ip_address, response_token.len(), jwt_token.len()); + + // First validate the captcha response + let is_valid = validate_captcha_response(response_token, 
+
+/// Validate captcha response
+pub async fn validate_captcha_response(
+    response_token: String,
+    ip_address: String,
+    user_agent: Option<String>,
+) -> Result<bool> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    let request = CaptchaValidationRequest {
+        response_token,
+        ip_address,
+        user_agent,
+        site_key: client.config.site_key.clone(),
+        secret_key: client.config.secret_key.clone(),
+        provider: client.config.provider.clone(),
+    };
+
+    client.validate_captcha(request).await
+}
+
+/// Generate captcha token
+pub async fn generate_captcha_token(
+    ip_address: String,
+    user_agent: String,
+    ja4_fingerprint: Option<String>,
+) -> Result<String> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    client.generate_token(ip_address, user_agent, ja4_fingerprint).await
+}
+
+/// Validate captcha token
+pub async fn validate_captcha_token(
+    token: &str,
+    ip_address: &str,
+    user_agent: &str,
+) -> Result<bool> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    client.validate_token(token, ip_address, user_agent).await
+}
+
+/// Apply captcha challenge
+pub fn apply_captcha_challenge() -> Result<String> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    Ok(client.apply_captcha_challenge(&client.config.site_key))
+}
+
+/// Apply captcha challenge with JWT token
+pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result<String> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token))
+}
+
+/// Get the captcha backend response key name
+pub fn get_captcha_backend_key() -> Result<&'static str> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    Ok(client.get_captcha_backend_key())
+}
+
+/// Mark a JWT token as validated after successful captcha completion
+pub async fn mark_captcha_token_validated(token: &str) -> Result<()> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    client.mark_token_validated(token).await
+}
+
+/// Revoke a JWT token
+pub async fn revoke_captcha_token(token: &str) -> Result<()> {
+    let client = CAPTCHA_CLIENT
+        .get()
+        .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?;
+
+    client.revoke_token(token).await
+}
+
+/// Validate captcha response and mark the JWT token as validated
+pub async fn validate_and_mark_captcha(
+    response_token: String,
+    jwt_token: String,
+    ip_address: String,
+    user_agent: Option<String>,
+) -> Result<bool> {
+    log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}",
+        ip_address, response_token.len(), jwt_token.len());
+
+    // First validate the captcha response
+    let is_valid = validate_captcha_response(response_token, ip_address.clone(), user_agent.clone()).await?;
+
+    log::info!("Captcha validation result: {}", is_valid);
+
+    if is_valid {
+        // Only try to mark the JWT token as validated if it's not empty
+        if !jwt_token.is_empty() {
+            if let Err(e) = mark_captcha_token_validated(&jwt_token).await {
+                log::warn!("Failed to mark JWT token as validated: {}", e);
+                // Don't return false here - captcha validation succeeded
+            } else {
+                log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address);
+            }
+        } else {
+            log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address);
+        }
+    } else {
+        log::warn!("Captcha validation failed for IP: {}", ip_address);
+    }
+
+    Ok(is_valid)
+}
+
+/// Start the periodic cache cleanup task
+pub async fn start_cache_cleanup_task() {
+    tokio::spawn(async {
+        let mut interval = tokio::time::interval(Duration::from_secs(60));
+        loop {
+            interval.tick().await;
+            if let Some(client) = CAPTCHA_CLIENT.get() {
+                client.cleanup_cache().await;
+            }
+        }
+    });
+}
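Taken together, these free functions are the module's public surface: initialize the global client once, spawn the cleanup task, then call validate_and_mark_captcha from the request path. A minimal sketch of that flow (the surrounding handler and the construction of config are hypothetical; the function names are the ones defined above):

    init_captcha_client(config).await?;
    start_cache_cleanup_task().await;

    // later, in a request handler:
    let ok = validate_and_mark_captcha(response_token, jwt_token, client_ip, Some(ua)).await?;
    if !ok {
        let page = apply_captcha_challenge()?; // re-serve the challenge
    }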
diff --git a/src/actions/mod.rs b/src/actions/mod.rs
index 9f4c525..a966519 100644
--- a/src/actions/mod.rs
+++ b/src/actions/mod.rs
@@ -1 +1 @@
-pub mod captcha;
+pub mod captcha;
diff --git a/src/agent_status.rs b/src/agent_status.rs
index 68fb1f3..4f2de29 100644
--- a/src/agent_status.rs
+++ b/src/agent_status.rs
@@ -1,73 +1,73 @@
-use std::collections::HashMap;
-
-use chrono::{DateTime, Utc};
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AgentStatusEvent {
-    pub event_type: String,
-    pub schema_version: String,
-    pub timestamp: DateTime<Utc>,
-    pub agent_id: String,
-    pub agent_name: String,
-    pub hostname: String,
-    pub version: String,
-    pub mode: String,
-    pub status: String,
-    pub pid: u32,
-    pub started_at: DateTime<Utc>,
-    pub last_seen: DateTime<Utc>,
-    pub uptime_secs: u64,
-    pub tags: Vec<String>,
-    pub capabilities: Vec<String>,
-    pub interfaces: Vec<String>,
-    pub ip_addresses: Vec<String>,
-    pub metadata: HashMap<String, String>,
-}
-
-#[derive(Debug, Clone)]
-pub struct AgentStatusIdentity {
-    pub agent_id: String,
-    pub agent_name: String,
-    pub hostname: String,
-    pub version: String,
-    pub mode: String,
-    pub tags: Vec<String>,
-    pub capabilities: Vec<String>,
-    pub interfaces: Vec<String>,
-    pub ip_addresses: Vec<String>,
-    pub metadata: HashMap<String, String>,
-    pub started_at: DateTime<Utc>,
-}
-
-impl AgentStatusIdentity {
-    pub fn to_event(&self, status: &str) -> AgentStatusEvent {
-        let now = Utc::now();
-        let uptime_secs = now
-            .signed_duration_since(self.started_at)
-            .num_seconds()
-            .max(0) as u64;
-
+use std::collections::HashMap;
+
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AgentStatusEvent {
+    pub event_type: String,
+    pub schema_version: String,
+    pub timestamp: DateTime<Utc>,
+    pub agent_id: String,
+    pub agent_name: String,
+    pub hostname: String,
+    pub version: String,
+    pub mode: String,
+    pub status: String,
+    pub pid: u32,
+    pub started_at: DateTime<Utc>,
+    pub last_seen: DateTime<Utc>,
+    pub uptime_secs: u64,
+    pub tags: Vec<String>,
+    pub capabilities: Vec<String>,
+    pub interfaces: Vec<String>,
+    pub ip_addresses: Vec<String>,
+    pub metadata: HashMap<String, String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct AgentStatusIdentity {
+    pub agent_id: String,
+    pub agent_name: String,
+    pub hostname: String,
+    pub version: String,
+    pub mode: String,
+    pub tags: Vec<String>,
+    pub capabilities: Vec<String>,
+    pub interfaces: Vec<String>,
+    pub ip_addresses: Vec<String>,
+    pub metadata: HashMap<String, String>,
+    pub started_at: DateTime<Utc>,
+}
+
+impl AgentStatusIdentity {
+    pub fn to_event(&self, status: &str) -> AgentStatusEvent {
+        let now = Utc::now();
+        let uptime_secs = now
+            .signed_duration_since(self.started_at)
+            .num_seconds()
+            .max(0) as u64;
+
         AgentStatusEvent {
             event_type: "agent_status".to_string(),
-            schema_version: "1.0.0".to_string(),
+            schema_version: "1.0".to_string(),
             timestamp: now,
             agent_id: self.agent_id.clone(),
-            agent_name: self.agent_name.clone(),
-            hostname: self.hostname.clone(),
-            version: self.version.clone(),
-            mode: self.mode.clone(),
-            status: status.to_string(),
-            pid: std::process::id(),
-            started_at: self.started_at,
-            last_seen: now,
-            uptime_secs,
-            tags: self.tags.clone(),
-            capabilities: self.capabilities.clone(),
-            interfaces: self.interfaces.clone(),
-            ip_addresses: self.ip_addresses.clone(),
-            metadata: self.metadata.clone(),
-        }
-    }
-}
-
+            agent_name: self.agent_name.clone(),
+            hostname: self.hostname.clone(),
+            version: self.version.clone(),
+            mode: self.mode.clone(),
+            status: status.to_string(),
+            pid: std::process::id(),
+            started_at: self.started_at,
+            last_seen: now,
+            uptime_secs,
+            tags: self.tags.clone(),
+            capabilities: self.capabilities.clone(),
+            interfaces: self.interfaces.clone(),
+            ip_addresses: self.ip_addresses.clone(),
+            metadata: self.metadata.clone(),
+        }
+    }
+}
+
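For reference, the status worker serializes the value returned by to_event as one JSON heartbeat. A sketch of producing such an event (all identity values below are made-up placeholders, and serde_json is assumed as the serializer):

    let identity = AgentStatusIdentity {
        agent_id: "agent-1".to_string(),
        agent_name: "edge-agent".to_string(),
        hostname: "host-01".to_string(),
        version: "0.1.0".to_string(),
        mode: "xdp".to_string(),
        tags: vec![],
        capabilities: vec!["xdp".to_string()],
        interfaces: vec!["eth0".to_string()],
        ip_addresses: vec!["203.0.113.10".to_string()],
        metadata: HashMap::new(),
        started_at: Utc::now(),
    };
    let heartbeat = serde_json::to_string(&identity.to_event("online"))?;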
diff --git a/src/app_state.rs b/src/app_state.rs
index 04361fb..365c66d 100644
--- a/src/app_state.rs
+++ b/src/app_state.rs
@@ -1,16 +1,16 @@
-use crate::{bpf};
-use std::sync::{Arc, Mutex};
-use crate::bpf_stats::BpfStatsCollector;
-use crate::firewall::{FirewallBackend, NftablesFirewall, IptablesFirewall};
-use crate::utils::tcp_fingerprint::TcpFingerprintCollector;
-
-#[derive(Clone)]
-pub struct AppState {
-    pub skels: Vec<Arc<Mutex<bpf::FilterSkel<'static>>>>,
-    pub ifindices: Vec<i32>,
-    pub bpf_stats_collector: BpfStatsCollector,
-    pub tcp_fingerprint_collector: TcpFingerprintCollector,
-    pub firewall_backend: FirewallBackend,
-    pub nftables_firewall: Option<Arc<Mutex<NftablesFirewall>>>,
-    pub iptables_firewall: Option<Arc<Mutex<IptablesFirewall>>>,
-}
+use crate::{bpf};
+use std::sync::{Arc, Mutex};
+use crate::bpf_stats::BpfStatsCollector;
+use crate::firewall::{FirewallBackend, NftablesFirewall, IptablesFirewall};
+use crate::utils::tcp_fingerprint::TcpFingerprintCollector;
+
+#[derive(Clone)]
+pub struct AppState {
+    pub skels: Vec<Arc<Mutex<bpf::FilterSkel<'static>>>>,
+    pub ifindices: Vec<i32>,
+    pub bpf_stats_collector: BpfStatsCollector,
+    pub tcp_fingerprint_collector: TcpFingerprintCollector,
+    pub firewall_backend: FirewallBackend,
+    pub nftables_firewall: Option<Arc<Mutex<NftablesFirewall>>>,
+    pub iptables_firewall: Option<Arc<Mutex<IptablesFirewall>>>,
+}
diff --git a/src/authcheck.rs b/src/authcheck.rs
index 4bbcda2..3f66fca 100644
--- a/src/authcheck.rs
+++ b/src/authcheck.rs
@@ -1,58 +1,58 @@
-use anyhow::{Context, Result};
-use serde::{Deserialize, Serialize};
-use crate::http_client::get_global_reqwest_client;
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct AuthCheckResponse {
-    pub success: bool,
-    pub message: Option<String>,
-}
-
-/// Validates the API key by calling the /authcheck endpoint
-pub async fn validate_api_key(base_url: &str, api_key: &str) -> Result<()> {
-    if base_url.is_empty() || api_key.is_empty() {
-        return Err(anyhow::anyhow!("Base URL and API key must be provided"));
-    }
-
-    // Use shared HTTP client with keepalive instead of creating new client
-    let client = get_global_reqwest_client()
-        .context("Failed to get global HTTP client")?;
-
-    let url = format!("{}/authcheck", base_url);
-
-    let response = client
-        .get(&url)
-        .header("Authorization", format!("Bearer {}", api_key))
-        .send()
-        .await
-        .context("Failed to send authcheck request")?;
-
-    match response.status() {
-        reqwest::StatusCode::OK => {
-            let auth_response: AuthCheckResponse = response
-                .json()
-                .await
-                .context("Failed to parse authcheck response")?;
-
-            if auth_response.success {
-                Ok(())
-            } else {
-                let error_msg = auth_response.message.unwrap_or_else(|| "Unknown error".to_string());
-                Err(anyhow::anyhow!("API key validation failed: {}", error_msg))
-            }
-        }
-        reqwest::StatusCode::UNAUTHORIZED => {
-            Err(anyhow::anyhow!("API key validation failed: Unauthorized (401)"))
-        }
-        reqwest::StatusCode::FORBIDDEN => {
-            Err(anyhow::anyhow!("API key validation failed: Forbidden (403)"))
-        }
-        status => {
-            Err(anyhow::anyhow!(
-                "API key validation failed with status: {} - {}",
-                status.as_u16(),
-                status.canonical_reason().unwrap_or("Unknown")
-            ))
-        }
-    }
-}
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use crate::http_client::get_global_reqwest_client;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct AuthCheckResponse {
+    pub success: bool,
+    pub message: Option<String>,
+}
+
+/// Validates the API key by calling the /authcheck endpoint
+pub async fn validate_api_key(base_url: &str, api_key: &str) -> Result<()> {
+    if base_url.is_empty() || api_key.is_empty() {
+        return Err(anyhow::anyhow!("Base URL and API key must be provided"));
+    }
+
+    // Use shared HTTP client with keepalive instead of creating new client
+    let client = get_global_reqwest_client()
+        .context("Failed to get global HTTP client")?;
+
+    let url = format!("{}/authcheck", base_url);
+
+    let response = client
+        .get(&url)
+        .header("Authorization", format!("Bearer {}", api_key))
+        .send()
+        .await
+        .context("Failed to send authcheck request")?;
+
+    match response.status() {
+        reqwest::StatusCode::OK => {
+            let auth_response: AuthCheckResponse = response
+                .json()
+                .await
+                .context("Failed to parse authcheck response")?;
+
+            if auth_response.success {
+                Ok(())
+            } else {
+                let error_msg = auth_response.message.unwrap_or_else(|| "Unknown error".to_string());
+                Err(anyhow::anyhow!("API key validation failed: {}", error_msg))
+            }
+        }
+        reqwest::StatusCode::UNAUTHORIZED => {
+            Err(anyhow::anyhow!("API key validation failed: Unauthorized (401)"))
+        }
+        reqwest::StatusCode::FORBIDDEN => {
+            Err(anyhow::anyhow!("API key validation failed: Forbidden (403)"))
+        }
+        status => {
+            Err(anyhow::anyhow!(
+                "API key validation failed with status: {} - {}",
+                status.as_u16(),
+                status.canonical_reason().unwrap_or("Unknown")
+            ))
+        }
+    }
+}
diff --git a/src/bpf/filter.bpf.c b/src/bpf/filter.bpf.c
index 981e2eb..58bfe8c 100644
--- a/src/bpf/filter.bpf.c
+++ b/src/bpf/filter.bpf.c
@@ -1,766 +1,766 @@
-#undef bpf // Undefine bpf macro to avoid conflict with struct netns_bpf bpf in vmlinux.h
-#include "vmlinux.h"
-#include <bpf/bpf_helpers.h>
-#include <bpf/bpf_endian.h>
-#include "filter.h"
-
-#define NF_DROP 0
-#define NF_ACCEPT 1
-#define ETH_P_IP 0x0800
-#define ETH_P_IPV6 0x86DD
-#define IP_MF 0x2000
-#define IP_OFFSET 0x1FFF
-#define NEXTHDR_FRAGMENT 44
-
-// TCP fingerprinting feature flag - comment out to disable and reduce program size
-#define ENABLE_TCP_FINGERPRINTING
-
-// TCP fingerprinting constants
-#define TCP_FINGERPRINT_MAX_ENTRIES 10000
-#define TCP_FP_KEY_SIZE 20 // 4 bytes IP + 2 bytes port + 14 bytes fingerprint
-#define TCP_FP_MAX_OPTIONS 10
-#define TCP_FP_MAX_OPTION_LEN 16 // Reduced from 40 to minimize BPF instruction count
-
-
-// Fragmentation checks removed - not currently used
-
-
-struct lpm_key {
-    __u32 prefixlen;
-    __be32 addr;
-};
-
-struct lpm_key_v6 {
-    __u32 prefixlen;
-    __u8 addr[16];
-};
-
-// TCP fingerprinting structures
-struct tcp_fingerprint_key {
- 
__be32 src_ip; // Source IP address (IPv4) - __be16 src_port; // Source port - __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) -}; - -struct tcp_fingerprint_key_v6 { - __u8 src_ip[16]; // Source IP address (IPv6) - __be16 src_port; // Source port - __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) -}; - -struct tcp_fingerprint_data { - __u64 first_seen; // Timestamp of first packet - __u64 last_seen; // Timestamp of last packet - __u32 packet_count; // Number of packets seen - __u16 ttl; // Initial TTL - __u16 mss; // Maximum Segment Size - __u16 window_size; // TCP window size - __u8 window_scale; // Window scaling factor - __u8 options_len; // Length of TCP options - __u8 options[TCP_FP_MAX_OPTION_LEN]; // TCP options data -}; - -struct tcp_syn_stats { - __u64 total_syns; - __u64 unique_fingerprints; - __u64 last_reset; -}; - -// IPv4 maps: permanently banned and recently banned -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key); // IPv4 address in network byte order - __type(value, ip_flag_t); // presence flag (1) -} banned_ips SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key); - __type(value, ip_flag_t); -} recently_banned_ips SEC(".maps"); - -// IPv6 maps: permanently banned and recently banned -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key_v6); - __type(value, ip_flag_t); -} banned_ips_v6 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key_v6); - __type(value, ip_flag_t); -} recently_banned_ips_v6 SEC(".maps"); - -// Remove dynptr helpers, not used in XDP manual parsing -// extern int bpf_dynptr_from_skb(struct __sk_buff *skb, __u64 flags, -// struct bpf_dynptr *ptr__uninit) __ksym; -// extern void *bpf_dynptr_slice(const struct bpf_dynptr *ptr, uint32_t offset, -// void *buffer, uint32_t buffer__sz) __ksym; - -volatile int shootdowns = 0; - -// Statistics maps for tracking access rule hits -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv4_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv4_recently_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv6_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv6_recently_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} total_packets_processed SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} total_packets_dropped SEC(".maps"); - -// TCP fingerprinting maps -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); - __type(key, struct tcp_fingerprint_key); - __type(value, struct tcp_fingerprint_data); -} tcp_fingerprints SEC(".maps"); - 
-struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); - __type(key, struct tcp_fingerprint_key_v6); - __type(value, struct tcp_fingerprint_data); -} tcp_fingerprints_v6 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, struct tcp_syn_stats); -} tcp_syn_stats SEC(".maps"); - -// Blocked TCP fingerprint maps (only store the fingerprint string, not per-IP) -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __uint(max_entries, 10000); // Store up to 10k blocked fingerprint patterns - __type(key, __u8[14]); // TCP fingerprint string (14 bytes) - __type(value, __u8); // Flag (1 = blocked) -} blocked_tcp_fingerprints SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __uint(max_entries, 10000); - __type(key, __u8[14]); // TCP fingerprint string (14 bytes) - __type(value, __u8); // Flag (1 = blocked) -} blocked_tcp_fingerprints_v6 SEC(".maps"); - -// Statistics for TCP fingerprint blocks -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} tcp_fingerprint_blocks_ipv4 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} tcp_fingerprint_blocks_ipv6 SEC(".maps"); - -// Maps to track dropped IP addresses with counters -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1000); // Track up to 1000 unique dropped IPs - __type(key, __be32); // IPv4 address - __type(value, __u64); // Drop count -} dropped_ipv4_addresses SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1000); // Track up to 1000 unique dropped IPv6s - __type(key, __u8[16]); // IPv6 address - __type(value, __u64); // Drop count -} dropped_ipv6_addresses SEC(".maps"); - -/* - * Helper for bounds checking and advancing a cursor. - * - * @cursor: pointer to current parsing position - * @end: pointer to end of packet data - * @len: length of the struct to read - * - * Returns a pointer to the struct if it's within bounds, - * and advances the cursor. Returns NULL otherwise. 
- */ -static void *parse_and_advance(void **cursor, void *end, __u32 len) -{ - void *current = *cursor; - if (current + len > end) - return NULL; - *cursor = current + len; - return current; -} - -/* - * Helper functions for incrementing statistics counters - */ -static void increment_ipv4_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv4_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv4_recently_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv4_recently_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv6_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv6_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv6_recently_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv6_recently_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_total_packets_processed(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&total_packets_processed, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_total_packets_dropped(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&total_packets_dropped, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_dropped_ipv4_address(__be32 ip_addr) -{ - __u64 *value = bpf_map_lookup_elem(&dropped_ipv4_addresses, &ip_addr); - if (value) { - __sync_fetch_and_add(value, 1); - } else { - // First time dropping this IP, initialize counter - __u64 initial_count = 1; - bpf_map_update_elem(&dropped_ipv4_addresses, &ip_addr, &initial_count, BPF_ANY); - } -} - -static void increment_dropped_ipv6_address(struct in6_addr ip_addr) -{ - __u8 *addr_bytes = (__u8 *)&ip_addr; - __u64 *value = bpf_map_lookup_elem(&dropped_ipv6_addresses, addr_bytes); - if (value) { - __sync_fetch_and_add(value, 1); - } else { - // First time dropping this IP, initialize counter - __u64 initial_count = 1; - bpf_map_update_elem(&dropped_ipv6_addresses, addr_bytes, &initial_count, BPF_ANY); - } -} - -/* - * TCP fingerprinting helper functions - */ -static void increment_tcp_syn_stats(void) -{ - __u32 key = 0; - struct tcp_syn_stats *stats = bpf_map_lookup_elem(&tcp_syn_stats, &key); - if (stats) { - __sync_fetch_and_add(&stats->total_syns, 1); - } else { - struct tcp_syn_stats new_stats = {0}; - new_stats.total_syns = 1; - bpf_map_update_elem(&tcp_syn_stats, &key, &new_stats, BPF_ANY); - } -} - -// increment_unique_fingerprints removed - fingerprint recording disabled - -static void increment_tcp_fingerprint_blocks_ipv4(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv4, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_tcp_fingerprint_blocks_ipv6(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv6, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static int parse_tcp_mss_wscale(struct tcphdr *tcp, void *data_end, __u16 *mss_out, __u8 *wscale_out) -{ - __u8 *ptr = (__u8 *)tcp + sizeof(struct tcphdr); - __u32 options_len = (tcp->doff * 4) - sizeof(struct tcphdr); - __u8 *end = ptr + options_len; - - // Ensure we don't exceed packet bounds - if (end > (__u8 *)data_end) { - end = (__u8 *)data_end; - } - - // Safety check - if (ptr >= end || ptr >= (__u8 
*)data_end) { - return 0; - } - - // Parse options - limit to 10 iterations to reduce instruction count - #pragma unroll - for (int i = 0; i < 10; i++) { - if (ptr >= end || ptr >= (__u8 *)data_end) break; - if (ptr + 1 > (__u8 *)data_end) break; - - __u8 kind = *ptr; - if (kind == 0) break; // End of options - - if (kind == 1) { - // NOP option - ptr++; - continue; - } - - // Check bounds for option length - if (ptr + 2 > (__u8 *)data_end) break; - __u8 len = *(ptr + 1); - if (len < 2 || ptr + len > (__u8 *)data_end) break; - - // MSS option (kind=2, len=4) - if (kind == 2 && len == 4 && ptr + 4 <= (__u8 *)data_end) { - *mss_out = (*(ptr + 2) << 8) | *(ptr + 3); - } - // Window scale option (kind=3, len=3) - else if (kind == 3 && len == 3 && ptr + 3 <= (__u8 *)data_end) { - *wscale_out = *(ptr + 2); - } - - ptr += len; - } - - return 0; -} - -static void generate_tcp_fingerprint(struct tcphdr *tcp, void *data_end, __u16 ttl, __u8 *fingerprint) -{ - // Generate JA4T-style fingerprint: ttl:mss:window:scale - __u16 mss = 0; - __u8 window_scale = 0; - - // Parse TCP options to extract MSS and window scaling - parse_tcp_mss_wscale(tcp, data_end, &mss, &window_scale); - - // Generate fingerprint string manually (BPF doesn't support complex formatting) - __u16 window = bpf_ntohs(tcp->window); - - // Format: "ttl:mss:window:scale" (max 14 chars) - fingerprint[0] = '0' + (ttl / 100); - fingerprint[1] = '0' + ((ttl / 10) % 10); - fingerprint[2] = '0' + (ttl % 10); - fingerprint[3] = ':'; - fingerprint[4] = '0' + (mss / 1000); - fingerprint[5] = '0' + ((mss / 100) % 10); - fingerprint[6] = '0' + ((mss / 10) % 10); - fingerprint[7] = '0' + (mss % 10); - fingerprint[8] = ':'; - fingerprint[9] = '0' + (window / 10000); - fingerprint[10] = '0' + ((window / 1000) % 10); - fingerprint[11] = '0' + ((window / 100) % 10); - fingerprint[12] = '0' + ((window / 10) % 10); - fingerprint[13] = '0' + (window % 10); - // Note: window_scale is not included due to space constraints -} - -/* - * Check if a TCP fingerprint is blocked (IPv4) - * Returns true if the fingerprint should be blocked - */ -static bool is_tcp_fingerprint_blocked(__u8 *fingerprint) -{ - __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints, fingerprint); - return (blocked != NULL && *blocked == 1); -} - -/* - * Check if a TCP fingerprint is blocked (IPv6) - * Returns true if the fingerprint should be blocked - */ -static bool is_tcp_fingerprint_blocked_v6(__u8 *fingerprint) -{ - __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints_v6, fingerprint); - return (blocked != NULL && *blocked == 1); -} - -// TCP fingerprint recording info struct to pass data with fewer args -struct fp_record_info { - __be32 src_ip_v4; - __be16 src_port; - __u16 ttl; - __u8 fingerprint[14]; - __u8 src_ip_v6[16]; -}; - -/* - * Record TCP fingerprint for monitoring (IPv4) - * Uses struct to pass data within BPF 5-arg limit - */ -static void record_tcp_fingerprint_v4(struct fp_record_info *info, - struct tcphdr *tcp, void *data_end) -{ - // Skip localhost (127.0.0.0/8) - if ((info->src_ip_v4 & bpf_htonl(0xff000000)) == bpf_htonl(0x7f000000)) - return; - - struct tcp_fingerprint_key key = {0}; - key.src_ip = info->src_ip_v4; - key.src_port = info->src_port; - __builtin_memcpy(key.fingerprint, info->fingerprint, 14); - - // Only create new entries - if (bpf_map_lookup_elem(&tcp_fingerprints, &key)) - return; - - struct tcp_fingerprint_data data = {0}; - data.first_seen = bpf_ktime_get_ns(); - data.last_seen = data.first_seen; - data.packet_count = 1; - 
data.ttl = info->ttl; - data.window_size = bpf_ntohs(tcp->window); - parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); - - bpf_map_update_elem(&tcp_fingerprints, &key, &data, BPF_NOEXIST); -} - -/* - * Record TCP fingerprint for monitoring (IPv6) - */ -static void record_tcp_fingerprint_v6(struct fp_record_info *info, - struct tcphdr *tcp, void *data_end) -{ - struct tcp_fingerprint_key_v6 key = {0}; - __builtin_memcpy(key.src_ip, info->src_ip_v6, 16); - key.src_port = info->src_port; - __builtin_memcpy(key.fingerprint, info->fingerprint, 14); - - // Only create new entries - if (bpf_map_lookup_elem(&tcp_fingerprints_v6, &key)) - return; - - struct tcp_fingerprint_data data = {0}; - data.first_seen = bpf_ktime_get_ns(); - data.last_seen = data.first_seen; - data.packet_count = 1; - data.ttl = info->ttl; - data.window_size = bpf_ntohs(tcp->window); - parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); - - bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_NOEXIST); -} - -SEC("xdp") -int arxignis_xdp_filter(struct xdp_md *ctx) -{ - // This filter is designed to only block incoming traffic - // It should be attached only to ingress hooks, not egress - // The filtering logic below blocks packets based on source IP addresses - // - // IP Version Support: - // - Supports IPv4-only, IPv6-only, and hybrid (both) modes - // - Note: XDP requires IPv6 to be enabled at kernel level for attachment, - // even when processing only IPv4 packets. This is a kernel limitation. - // - The BPF program processes both IPv4 and IPv6 packets based on the - // ethernet protocol type (ETH_P_IP for IPv4, ETH_P_IPV6 for IPv6) - - void *data_end = (void *)(long)ctx->data_end; - void *cursor = (void *)(long)ctx->data; - - // Debug: Count all packets - __u32 zero = 0; - __u32 *packet_count = bpf_map_lookup_elem(&total_packets_processed, &zero); - if (packet_count) { - __sync_fetch_and_add(packet_count, 1); - } - - struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth)); - if (!eth) - return XDP_PASS; - - __u16 h_proto = eth->h_proto; - - // Increment total packets processed counter - increment_total_packets_processed(); - - if (h_proto == bpf_htons(ETH_P_IP)) { - struct iphdr *iph = parse_and_advance(&cursor, data_end, sizeof(*iph)); - if (!iph) - return XDP_PASS; - - struct lpm_key key = { - .prefixlen = 32, - .addr = iph->saddr, - }; - - if (bpf_map_lookup_elem(&banned_ips, &key)) { - increment_ipv4_banned_stats(); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming permanently banned IPv4 %pI4", &iph->saddr); - return XDP_DROP; - } - - if (bpf_map_lookup_elem(&recently_banned_ips, &key)) { - increment_ipv4_recently_banned_stats(); - // Block UDP and ICMP from recently banned IPs, but allow DNS - if (iph->protocol == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && udph->dest == bpf_htons(53)) { - return XDP_PASS; // Allow DNS responses - } - // Block other UDP traffic - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); - return XDP_DROP; - } - if (iph->protocol == IPPROTO_ICMP) { - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - 
bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming ICMP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); - return XDP_DROP; - } - // For TCP, promote to banned on FIN/RST - if (iph->protocol == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph && (tcph->fin || tcph->rst)) { - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - } - } - return XDP_PASS; - } - - // Perform TCP fingerprinting ONLY on SYN packets - if (iph->protocol == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph) { - // Only fingerprint SYN packets (not SYN-ACK) to capture MSS/WSCALE - if (tcph->syn && !tcph->ack) { - increment_tcp_syn_stats(); - - // Generate fingerprint to check if blocked - __u8 fingerprint[14] = {0}; - generate_tcp_fingerprint(tcph, data_end, iph->ttl, fingerprint); - - // Check if this TCP fingerprint is blocked - if (is_tcp_fingerprint_blocked(fingerprint)) { - increment_tcp_fingerprint_blocks_ipv4(); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - return XDP_DROP; - } - // Record fingerprint for monitoring - struct fp_record_info fp_info = {0}; - fp_info.src_ip_v4 = iph->saddr; - fp_info.src_port = tcph->source; - fp_info.ttl = iph->ttl; - __builtin_memcpy(fp_info.fingerprint, fingerprint, 14); - record_tcp_fingerprint_v4(&fp_info, tcph, data_end); - } - } - } - - return XDP_PASS; - } - else if (h_proto == bpf_htons(ETH_P_IPV6)) { - struct ipv6hdr *ip6h = parse_and_advance(&cursor, data_end, sizeof(*ip6h)); - if (!ip6h) - return XDP_PASS; - - // Always allow DNS traffic (UDP port 53) to pass through - if (ip6h->nexthdr == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && (udph->dest == bpf_htons(53) || udph->source == bpf_htons(53))) { - return XDP_PASS; // Always allow DNS traffic - } - } - - // Check banned/recently banned maps by source IPv6 - struct lpm_key_v6 key6 = { - .prefixlen = 128, - }; - __u8 *src_addr = (__u8 *)&ip6h->saddr; - __builtin_memcpy(key6.addr, src_addr, 16); - - if (bpf_map_lookup_elem(&banned_ips_v6, &key6)) { - increment_ipv6_banned_stats(); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: BLOCKED incoming permanently banned IPv6"); - return XDP_DROP; - } - - if (bpf_map_lookup_elem(&recently_banned_ips_v6, &key6)) { - increment_ipv6_recently_banned_stats(); - // Block UDP and ICMP from recently banned IPv6 IPs, but allow DNS - if (ip6h->nexthdr == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && udph->dest == bpf_htons(53)) { - return XDP_PASS; // Allow DNS responses - } - // Block other UDP traffic - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips_v6, &key6); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv6, promoted to permanent ban"); - return XDP_DROP; - } - if (ip6h->nexthdr == 58) { // 58 = IPPROTO_ICMPV6 - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); - 
bpf_map_delete_elem(&recently_banned_ips_v6, &key6);
-            increment_total_packets_dropped();
-            increment_dropped_ipv6_address(ip6h->saddr);
-            //bpf_printk("XDP: BLOCKED incoming ICMPv6 from recently banned IPv6, promoted to permanent ban");
-            return XDP_DROP;
-        }
-        // For TCP, only promote to banned on FIN/RST
-        if (ip6h->nexthdr == IPPROTO_TCP) {
-            struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph));
-            if (tcph) {
-                if (tcph->fin || tcph->rst) {
-                    ip_flag_t one = 1;
-                    bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY);
-                    bpf_map_delete_elem(&recently_banned_ips_v6, &key6);
-                    increment_total_packets_dropped();
-                    increment_dropped_ipv6_address(ip6h->saddr);
-                    //bpf_printk("XDP: TCP FIN/RST from incoming recently banned IPv6, promoted to permanent ban");
-                }
-            }
-        }
-        return XDP_PASS; // Allow if recently banned
-    }
-
-    // Perform TCP fingerprinting on IPv6 TCP packets
-    if (ip6h->nexthdr == IPPROTO_TCP) {
-        struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph));
-        if (tcph) {
-            // Perform TCP fingerprinting ONLY on SYN packets (not SYN-ACK)
-            // This ensures we capture the initial handshake with MSS/WSCALE
-            if (tcph->syn && !tcph->ack) {
-                // Skip IPv6 localhost traffic (::1) - simplified check
-                __u8 *src_addr = (__u8 *)&ip6h->saddr;
-                // Quick localhost check: first 8 bytes zero, bytes 8-14 zero, byte 15 is 1
-                if (src_addr[0] == 0 && src_addr[7] == 0 &&
-                    src_addr[8] == 0 && src_addr[14] == 0 && src_addr[15] == 1) {
-                    return XDP_PASS;
-                }
-
-                // Extract TTL from IPv6 hop limit
-                __u16 ttl = ip6h->hop_limit;
-
-                // Generate fingerprint to check if blocked
-                __u8 fingerprint[14] = {0};
-                generate_tcp_fingerprint(tcph, data_end, ttl, fingerprint);
-
-                // Check if this TCP fingerprint is blocked
-                if (is_tcp_fingerprint_blocked_v6(fingerprint)) {
-                    increment_tcp_fingerprint_blocks_ipv6();
-                    increment_total_packets_dropped();
-                    increment_dropped_ipv6_address(ip6h->saddr);
-                    return XDP_DROP;
-                }
-                // Record fingerprint for monitoring
-                struct fp_record_info fp_info = {0};
-                __builtin_memcpy(fp_info.src_ip_v6, src_addr, 16);
-                fp_info.src_port = tcph->source;
-                fp_info.ttl = ttl;
-                __builtin_memcpy(fp_info.fingerprint, fingerprint, 14);
-                record_tcp_fingerprint_v6(&fp_info, tcph, data_end);
-            }
-        }
-    }
-
-    return XDP_PASS;
-    }
-
-    return XDP_PASS;
-    // return XDP_ABORTED;
-}
-
-char _license[] SEC("license") = "GPL";
+#undef bpf // Undefine bpf macro to avoid conflict with struct netns_bpf bpf in vmlinux.h
+#include "vmlinux.h"
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_endian.h>
+#include "filter.h"
+
+#define NF_DROP 0
+#define NF_ACCEPT 1
+#define ETH_P_IP 0x0800
+#define ETH_P_IPV6 0x86DD
+#define IP_MF 0x2000
+#define IP_OFFSET 0x1FFF
+#define NEXTHDR_FRAGMENT 44
+
+// TCP fingerprinting feature flag - comment out to disable and reduce program size
+#define ENABLE_TCP_FINGERPRINTING
+
+// TCP fingerprinting constants
+#define TCP_FINGERPRINT_MAX_ENTRIES 10000
+#define TCP_FP_KEY_SIZE 20 // 4 bytes IP + 2 bytes port + 14 bytes fingerprint
+#define TCP_FP_MAX_OPTIONS 10
+#define TCP_FP_MAX_OPTION_LEN 16 // Reduced from 40 to minimize BPF instruction count
+
+
+// Fragmentation checks removed - not currently used
+
+
+struct lpm_key {
+    __u32 prefixlen;
+    __be32 addr;
+};
+
+struct lpm_key_v6 {
+    __u32 prefixlen;
+    __u8 addr[16];
+};
+
+// TCP fingerprinting structures
+struct tcp_fingerprint_key {
+    __be32 src_ip;        // Source IP address (IPv4)
+    __be16 src_port;      // Source port
+    __u8 fingerprint[14]; // TCP fingerprint string (null-terminated)
+};
+
+struct tcp_fingerprint_key_v6 
{ + __u8 src_ip[16]; // Source IP address (IPv6) + __be16 src_port; // Source port + __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) +}; + +struct tcp_fingerprint_data { + __u64 first_seen; // Timestamp of first packet + __u64 last_seen; // Timestamp of last packet + __u32 packet_count; // Number of packets seen + __u16 ttl; // Initial TTL + __u16 mss; // Maximum Segment Size + __u16 window_size; // TCP window size + __u8 window_scale; // Window scaling factor + __u8 options_len; // Length of TCP options + __u8 options[TCP_FP_MAX_OPTION_LEN]; // TCP options data +}; + +struct tcp_syn_stats { + __u64 total_syns; + __u64 unique_fingerprints; + __u64 last_reset; +}; + +// IPv4 maps: permanently banned and recently banned +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key); // IPv4 address in network byte order + __type(value, ip_flag_t); // presence flag (1) +} banned_ips SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key); + __type(value, ip_flag_t); +} recently_banned_ips SEC(".maps"); + +// IPv6 maps: permanently banned and recently banned +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key_v6); + __type(value, ip_flag_t); +} banned_ips_v6 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key_v6); + __type(value, ip_flag_t); +} recently_banned_ips_v6 SEC(".maps"); + +// Remove dynptr helpers, not used in XDP manual parsing +// extern int bpf_dynptr_from_skb(struct __sk_buff *skb, __u64 flags, +// struct bpf_dynptr *ptr__uninit) __ksym; +// extern void *bpf_dynptr_slice(const struct bpf_dynptr *ptr, uint32_t offset, +// void *buffer, uint32_t buffer__sz) __ksym; + +volatile int shootdowns = 0; + +// Statistics maps for tracking access rule hits +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv4_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv4_recently_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv6_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv6_recently_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} total_packets_processed SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} total_packets_dropped SEC(".maps"); + +// TCP fingerprinting maps +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); + __type(key, struct tcp_fingerprint_key); + __type(value, struct tcp_fingerprint_data); +} tcp_fingerprints SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); + __type(key, struct tcp_fingerprint_key_v6); + __type(value, struct 
tcp_fingerprint_data); +} tcp_fingerprints_v6 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, struct tcp_syn_stats); +} tcp_syn_stats SEC(".maps"); + +// Blocked TCP fingerprint maps (only store the fingerprint string, not per-IP) +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 10000); // Store up to 10k blocked fingerprint patterns + __type(key, __u8[14]); // TCP fingerprint string (14 bytes) + __type(value, __u8); // Flag (1 = blocked) +} blocked_tcp_fingerprints SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 10000); + __type(key, __u8[14]); // TCP fingerprint string (14 bytes) + __type(value, __u8); // Flag (1 = blocked) +} blocked_tcp_fingerprints_v6 SEC(".maps"); + +// Statistics for TCP fingerprint blocks +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} tcp_fingerprint_blocks_ipv4 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} tcp_fingerprint_blocks_ipv6 SEC(".maps"); + +// Maps to track dropped IP addresses with counters +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1000); // Track up to 1000 unique dropped IPs + __type(key, __be32); // IPv4 address + __type(value, __u64); // Drop count +} dropped_ipv4_addresses SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1000); // Track up to 1000 unique dropped IPv6s + __type(key, __u8[16]); // IPv6 address + __type(value, __u64); // Drop count +} dropped_ipv6_addresses SEC(".maps"); + +/* + * Helper for bounds checking and advancing a cursor. + * + * @cursor: pointer to current parsing position + * @end: pointer to end of packet data + * @len: length of the struct to read + * + * Returns a pointer to the struct if it's within bounds, + * and advances the cursor. Returns NULL otherwise. 
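+ *
+ * Typical use, as in the XDP program below:
+ *
+ *     struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth));
+ *     if (!eth)
+ *         return XDP_PASS;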
+ */ +static void *parse_and_advance(void **cursor, void *end, __u32 len) +{ + void *current = *cursor; + if (current + len > end) + return NULL; + *cursor = current + len; + return current; +} + +/* + * Helper functions for incrementing statistics counters + */ +static void increment_ipv4_banned_stats(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv4_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_ipv4_recently_banned_stats(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv4_recently_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_ipv6_banned_stats(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv6_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_ipv6_recently_banned_stats(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv6_recently_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_total_packets_processed(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&total_packets_processed, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_total_packets_dropped(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&total_packets_dropped, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_dropped_ipv4_address(__be32 ip_addr) +{ + __u64 *value = bpf_map_lookup_elem(&dropped_ipv4_addresses, &ip_addr); + if (value) { + __sync_fetch_and_add(value, 1); + } else { + // First time dropping this IP, initialize counter + __u64 initial_count = 1; + bpf_map_update_elem(&dropped_ipv4_addresses, &ip_addr, &initial_count, BPF_ANY); + } +} + +static void increment_dropped_ipv6_address(struct in6_addr ip_addr) +{ + __u8 *addr_bytes = (__u8 *)&ip_addr; + __u64 *value = bpf_map_lookup_elem(&dropped_ipv6_addresses, addr_bytes); + if (value) { + __sync_fetch_and_add(value, 1); + } else { + // First time dropping this IP, initialize counter + __u64 initial_count = 1; + bpf_map_update_elem(&dropped_ipv6_addresses, addr_bytes, &initial_count, BPF_ANY); + } +} + +/* + * TCP fingerprinting helper functions + */ +static void increment_tcp_syn_stats(void) +{ + __u32 key = 0; + struct tcp_syn_stats *stats = bpf_map_lookup_elem(&tcp_syn_stats, &key); + if (stats) { + __sync_fetch_and_add(&stats->total_syns, 1); + } else { + struct tcp_syn_stats new_stats = {0}; + new_stats.total_syns = 1; + bpf_map_update_elem(&tcp_syn_stats, &key, &new_stats, BPF_ANY); + } +} + +// increment_unique_fingerprints removed - fingerprint recording disabled + +static void increment_tcp_fingerprint_blocks_ipv4(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv4, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static void increment_tcp_fingerprint_blocks_ipv6(void) +{ + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv6, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static int parse_tcp_mss_wscale(struct tcphdr *tcp, void *data_end, __u16 *mss_out, __u8 *wscale_out) +{ + __u8 *ptr = (__u8 *)tcp + sizeof(struct tcphdr); + __u32 options_len = (tcp->doff * 4) - sizeof(struct tcphdr); + __u8 *end = ptr + options_len; + + // Ensure we don't exceed packet bounds + if (end > (__u8 *)data_end) { + end = (__u8 *)data_end; + } + + // Safety check + if (ptr >= end || ptr >= (__u8 
*)data_end) { + return 0; + } + + // Parse options - limit to 10 iterations to reduce instruction count + #pragma unroll + for (int i = 0; i < 10; i++) { + if (ptr >= end || ptr >= (__u8 *)data_end) break; + if (ptr + 1 > (__u8 *)data_end) break; + + __u8 kind = *ptr; + if (kind == 0) break; // End of options + + if (kind == 1) { + // NOP option + ptr++; + continue; + } + + // Check bounds for option length + if (ptr + 2 > (__u8 *)data_end) break; + __u8 len = *(ptr + 1); + if (len < 2 || ptr + len > (__u8 *)data_end) break; + + // MSS option (kind=2, len=4) + if (kind == 2 && len == 4 && ptr + 4 <= (__u8 *)data_end) { + *mss_out = (*(ptr + 2) << 8) | *(ptr + 3); + } + // Window scale option (kind=3, len=3) + else if (kind == 3 && len == 3 && ptr + 3 <= (__u8 *)data_end) { + *wscale_out = *(ptr + 2); + } + + ptr += len; + } + + return 0; +} + +static void generate_tcp_fingerprint(struct tcphdr *tcp, void *data_end, __u16 ttl, __u8 *fingerprint) +{ + // Generate JA4T-style fingerprint: ttl:mss:window:scale + __u16 mss = 0; + __u8 window_scale = 0; + + // Parse TCP options to extract MSS and window scaling + parse_tcp_mss_wscale(tcp, data_end, &mss, &window_scale); + + // Generate fingerprint string manually (BPF doesn't support complex formatting) + __u16 window = bpf_ntohs(tcp->window); + + // Format: "ttl:mss:window:scale" (max 14 chars) + fingerprint[0] = '0' + (ttl / 100); + fingerprint[1] = '0' + ((ttl / 10) % 10); + fingerprint[2] = '0' + (ttl % 10); + fingerprint[3] = ':'; + fingerprint[4] = '0' + (mss / 1000); + fingerprint[5] = '0' + ((mss / 100) % 10); + fingerprint[6] = '0' + ((mss / 10) % 10); + fingerprint[7] = '0' + (mss % 10); + fingerprint[8] = ':'; + fingerprint[9] = '0' + (window / 10000); + fingerprint[10] = '0' + ((window / 1000) % 10); + fingerprint[11] = '0' + ((window / 100) % 10); + fingerprint[12] = '0' + ((window / 10) % 10); + fingerprint[13] = '0' + (window % 10); + // Note: window_scale is not included due to space constraints +} + +/* + * Check if a TCP fingerprint is blocked (IPv4) + * Returns true if the fingerprint should be blocked + */ +static bool is_tcp_fingerprint_blocked(__u8 *fingerprint) +{ + __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints, fingerprint); + return (blocked != NULL && *blocked == 1); +} + +/* + * Check if a TCP fingerprint is blocked (IPv6) + * Returns true if the fingerprint should be blocked + */ +static bool is_tcp_fingerprint_blocked_v6(__u8 *fingerprint) +{ + __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints_v6, fingerprint); + return (blocked != NULL && *blocked == 1); +} + +// TCP fingerprint recording info struct to pass data with fewer args +struct fp_record_info { + __be32 src_ip_v4; + __be16 src_port; + __u16 ttl; + __u8 fingerprint[14]; + __u8 src_ip_v6[16]; +}; + +/* + * Record TCP fingerprint for monitoring (IPv4) + * Uses struct to pass data within BPF 5-arg limit + */ +static void record_tcp_fingerprint_v4(struct fp_record_info *info, + struct tcphdr *tcp, void *data_end) +{ + // Skip localhost (127.0.0.0/8) + if ((info->src_ip_v4 & bpf_htonl(0xff000000)) == bpf_htonl(0x7f000000)) + return; + + struct tcp_fingerprint_key key = {0}; + key.src_ip = info->src_ip_v4; + key.src_port = info->src_port; + __builtin_memcpy(key.fingerprint, info->fingerprint, 14); + + // Only create new entries + if (bpf_map_lookup_elem(&tcp_fingerprints, &key)) + return; + + struct tcp_fingerprint_data data = {0}; + data.first_seen = bpf_ktime_get_ns(); + data.last_seen = data.first_seen; + data.packet_count = 1; + 
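// Snapshot the SYN's TTL and window size; MSS and window scale are
+    // recovered from the TCP options below.
+    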
data.ttl = info->ttl; + data.window_size = bpf_ntohs(tcp->window); + parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); + + bpf_map_update_elem(&tcp_fingerprints, &key, &data, BPF_NOEXIST); +} + +/* + * Record TCP fingerprint for monitoring (IPv6) + */ +static void record_tcp_fingerprint_v6(struct fp_record_info *info, + struct tcphdr *tcp, void *data_end) +{ + struct tcp_fingerprint_key_v6 key = {0}; + __builtin_memcpy(key.src_ip, info->src_ip_v6, 16); + key.src_port = info->src_port; + __builtin_memcpy(key.fingerprint, info->fingerprint, 14); + + // Only create new entries + if (bpf_map_lookup_elem(&tcp_fingerprints_v6, &key)) + return; + + struct tcp_fingerprint_data data = {0}; + data.first_seen = bpf_ktime_get_ns(); + data.last_seen = data.first_seen; + data.packet_count = 1; + data.ttl = info->ttl; + data.window_size = bpf_ntohs(tcp->window); + parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); + + bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_NOEXIST); +} + +SEC("xdp") +int arxignis_xdp_filter(struct xdp_md *ctx) +{ + // This filter is designed to only block incoming traffic + // It should be attached only to ingress hooks, not egress + // The filtering logic below blocks packets based on source IP addresses + // + // IP Version Support: + // - Supports IPv4-only, IPv6-only, and hybrid (both) modes + // - Note: XDP requires IPv6 to be enabled at kernel level for attachment, + // even when processing only IPv4 packets. This is a kernel limitation. + // - The BPF program processes both IPv4 and IPv6 packets based on the + // ethernet protocol type (ETH_P_IP for IPv4, ETH_P_IPV6 for IPv6) + + void *data_end = (void *)(long)ctx->data_end; + void *cursor = (void *)(long)ctx->data; + + // Debug: Count all packets + __u32 zero = 0; + __u32 *packet_count = bpf_map_lookup_elem(&total_packets_processed, &zero); + if (packet_count) { + __sync_fetch_and_add(packet_count, 1); + } + + struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth)); + if (!eth) + return XDP_PASS; + + __u16 h_proto = eth->h_proto; + + // Increment total packets processed counter + increment_total_packets_processed(); + + if (h_proto == bpf_htons(ETH_P_IP)) { + struct iphdr *iph = parse_and_advance(&cursor, data_end, sizeof(*iph)); + if (!iph) + return XDP_PASS; + + struct lpm_key key = { + .prefixlen = 32, + .addr = iph->saddr, + }; + + if (bpf_map_lookup_elem(&banned_ips, &key)) { + increment_ipv4_banned_stats(); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + //bpf_printk("XDP: BLOCKED incoming permanently banned IPv4 %pI4", &iph->saddr); + return XDP_DROP; + } + + if (bpf_map_lookup_elem(&recently_banned_ips, &key)) { + increment_ipv4_recently_banned_stats(); + // Block UDP and ICMP from recently banned IPs, but allow DNS + if (iph->protocol == IPPROTO_UDP) { + struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); + if (udph && udph->dest == bpf_htons(53)) { + return XDP_PASS; // Allow DNS responses + } + // Block other UDP traffic + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); + bpf_map_delete_elem(&recently_banned_ips, &key); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); + return XDP_DROP; + } + if (iph->protocol == IPPROTO_ICMP) { + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); + 
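// The ban is now permanent, so drop the stale recently-banned entry.
+            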
bpf_map_delete_elem(&recently_banned_ips, &key); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + //bpf_printk("XDP: BLOCKED incoming ICMP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); + return XDP_DROP; + } + // For TCP, promote to banned on FIN/RST + if (iph->protocol == IPPROTO_TCP) { + struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); + if (tcph && (tcph->fin || tcph->rst)) { + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); + bpf_map_delete_elem(&recently_banned_ips, &key); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + } + } + return XDP_PASS; + } + + // Perform TCP fingerprinting ONLY on SYN packets + if (iph->protocol == IPPROTO_TCP) { + struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); + if (tcph) { + // Only fingerprint SYN packets (not SYN-ACK) to capture MSS/WSCALE + if (tcph->syn && !tcph->ack) { + increment_tcp_syn_stats(); + + // Generate fingerprint to check if blocked + __u8 fingerprint[14] = {0}; + generate_tcp_fingerprint(tcph, data_end, iph->ttl, fingerprint); + + // Check if this TCP fingerprint is blocked + if (is_tcp_fingerprint_blocked(fingerprint)) { + increment_tcp_fingerprint_blocks_ipv4(); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + return XDP_DROP; + } + // Record fingerprint for monitoring + struct fp_record_info fp_info = {0}; + fp_info.src_ip_v4 = iph->saddr; + fp_info.src_port = tcph->source; + fp_info.ttl = iph->ttl; + __builtin_memcpy(fp_info.fingerprint, fingerprint, 14); + record_tcp_fingerprint_v4(&fp_info, tcph, data_end); + } + } + } + + return XDP_PASS; + } + else if (h_proto == bpf_htons(ETH_P_IPV6)) { + struct ipv6hdr *ip6h = parse_and_advance(&cursor, data_end, sizeof(*ip6h)); + if (!ip6h) + return XDP_PASS; + + // Always allow DNS traffic (UDP port 53) to pass through + if (ip6h->nexthdr == IPPROTO_UDP) { + struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); + if (udph && (udph->dest == bpf_htons(53) || udph->source == bpf_htons(53))) { + return XDP_PASS; // Always allow DNS traffic + } + } + + // Check banned/recently banned maps by source IPv6 + struct lpm_key_v6 key6 = { + .prefixlen = 128, + }; + __u8 *src_addr = (__u8 *)&ip6h->saddr; + __builtin_memcpy(key6.addr, src_addr, 16); + + if (bpf_map_lookup_elem(&banned_ips_v6, &key6)) { + increment_ipv6_banned_stats(); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + //bpf_printk("XDP: BLOCKED incoming permanently banned IPv6"); + return XDP_DROP; + } + + if (bpf_map_lookup_elem(&recently_banned_ips_v6, &key6)) { + increment_ipv6_recently_banned_stats(); + // Block UDP and ICMP from recently banned IPv6 IPs, but allow DNS + if (ip6h->nexthdr == IPPROTO_UDP) { + struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); + if (udph && udph->dest == bpf_htons(53)) { + return XDP_PASS; // Allow DNS responses + } + // Block other UDP traffic + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); + bpf_map_delete_elem(&recently_banned_ips_v6, &key6); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv6, promoted to permanent ban"); + return XDP_DROP; + } + if (ip6h->nexthdr == 58) { // 58 = IPPROTO_ICMPV6 + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); + 
bpf_map_delete_elem(&recently_banned_ips_v6, &key6); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + //bpf_printk("XDP: BLOCKED incoming ICMPv6 from recently banned IPv6, promoted to permanent ban"); + return XDP_DROP; + } + // For TCP, only promote to banned on FIN/RST + if (ip6h->nexthdr == IPPROTO_TCP) { + struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); + if (tcph) { + if (tcph->fin || tcph->rst) { + ip_flag_t one = 1; + bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); + bpf_map_delete_elem(&recently_banned_ips_v6, &key6); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + //bpf_printk("XDP: TCP FIN/RST from incoming recently banned IPv6, promoted to permanent ban"); + } + } + } + return XDP_PASS; // Allow if recently banned + } + + // Perform TCP fingerprinting on IPv6 TCP packets + if (ip6h->nexthdr == IPPROTO_TCP) { + struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); + if (tcph) { + // Perform TCP fingerprinting ONLY on SYN packets (not SYN-ACK) + // This ensures we capture the initial handshake with MSS/WSCALE + if (tcph->syn && !tcph->ack) { + // Skip IPv6 localhost traffic (::1) - simplified check + __u8 *src_addr = (__u8 *)&ip6h->saddr; + // Quick localhost check: first 8 bytes zero, bytes 8-14 zero, byte 15 is 1 + if (src_addr[0] == 0 && src_addr[7] == 0 && + src_addr[8] == 0 && src_addr[14] == 0 && src_addr[15] == 1) { + return XDP_PASS; + } + + // Extract TTL from IPv6 hop limit + __u16 ttl = ip6h->hop_limit; + + // Generate fingerprint to check if blocked + __u8 fingerprint[14] = {0}; + generate_tcp_fingerprint(tcph, data_end, ttl, fingerprint); + + // Check if this TCP fingerprint is blocked + if (is_tcp_fingerprint_blocked_v6(fingerprint)) { + increment_tcp_fingerprint_blocks_ipv6(); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + return XDP_DROP; + } + // Record fingerprint for monitoring + struct fp_record_info fp_info = {0}; + __builtin_memcpy(fp_info.src_ip_v6, src_addr, 16); + fp_info.src_port = tcph->source; + fp_info.ttl = ttl; + __builtin_memcpy(fp_info.fingerprint, fingerprint, 14); + record_tcp_fingerprint_v6(&fp_info, tcph, data_end); + } + } + } + + return XDP_PASS; + } + + return XDP_PASS; + // return XDP_ABORTED; +} + +char _license[] SEC("license") = "GPL"; diff --git a/src/bpf/filter.h b/src/bpf/filter.h index 747c731..bdb8784 100644 --- a/src/bpf/filter.h +++ b/src/bpf/filter.h @@ -1,11 +1,11 @@ -// SPDX-License-Identifier: GPL-2.0 -#pragma once - -// Maximum entries for IP maps -#ifndef CITADEL_IP_MAP_MAX -#define CITADEL_IP_MAP_MAX 65536 -#endif - -// Map value is a simple flag (present = 1) -typedef __u8 ip_flag_t; - +// SPDX-License-Identifier: GPL-2.0 +#pragma once + +// Maximum entries for IP maps +#ifndef CITADEL_IP_MAP_MAX +#define CITADEL_IP_MAP_MAX 65536 +#endif + +// Map value is a simple flag (present = 1) +typedef __u8 ip_flag_t; + diff --git a/src/bpf_stats.rs b/src/bpf_stats.rs index 6c3f88b..3b37b44 100644 --- a/src/bpf_stats.rs +++ b/src/bpf_stats.rs @@ -1,729 +1,729 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; -use chrono::{DateTime, Utc}; -use std::collections::HashMap; -use std::net::{Ipv4Addr, Ipv6Addr}; -use libbpf_rs::MapCore; -use crate::worker::log::{send_event, UnifiedEvent}; - -use crate::bpf::FilterSkel; - -/// BPF statistics collected from kernel-level access rule enforcement -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct 
BpfAccessStats { - pub timestamp: DateTime, - pub total_packets_processed: u64, - pub total_packets_dropped: u64, - pub ipv4_banned_hits: u64, - pub ipv4_recently_banned_hits: u64, - pub ipv6_banned_hits: u64, - pub ipv6_recently_banned_hits: u64, - pub tcp_fingerprint_blocks_ipv4: u64, - pub tcp_fingerprint_blocks_ipv6: u64, - pub drop_rate_percentage: f64, - pub dropped_ip_addresses: DroppedIpAddresses, -} - -/// Statistics about dropped IP addresses -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpAddresses { - pub ipv4_addresses: HashMap, // IP address -> drop count - pub ipv6_addresses: HashMap, // IP address -> drop count - pub total_unique_dropped_ips: u64, -} - -/// Individual event for a dropped IP address -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvent { - pub event_type: String, - pub timestamp: DateTime, - pub ip_address: String, - pub ip_version: IpVersion, - pub drop_count: u64, - pub drop_reason: DropReason -} - -/// IP version enumeration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum IpVersion { - IPv4, - IPv6, -} - -/// Reason for dropping packets -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum DropReason { - AccessRules, - RecentlyBannedUdp, - RecentlyBannedIcmp, - RecentlyBannedTcpFinRst, -} - -/// Collection of dropped IP events -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvents { - pub timestamp: DateTime, - pub events: Vec, - pub total_events: u64, - pub unique_ips: u64, -} - -impl BpfAccessStats { - /// Create a new statistics snapshot from BPF maps - pub fn from_bpf_maps(skel: &FilterSkel) -> Result> { - let timestamp = Utc::now(); - - // Read statistics from BPF maps - let total_packets_processed = Self::read_bpf_counter(&skel.maps.total_packets_processed)?; - let total_packets_dropped = Self::read_bpf_counter(&skel.maps.total_packets_dropped)?; - let ipv4_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_banned_stats)?; - let ipv4_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_recently_banned_stats)?; - let ipv6_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_banned_stats)?; - let ipv6_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_recently_banned_stats)?; - let tcp_fingerprint_blocks_ipv4 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv4)?; - let tcp_fingerprint_blocks_ipv6 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv6)?; - - // Collect dropped IP addresses - let dropped_ip_addresses = Self::collect_dropped_ip_addresses(skel)?; - - // Calculate drop rate percentage - let drop_rate_percentage = if total_packets_processed > 0 { - (total_packets_dropped as f64 / total_packets_processed as f64) * 100.0 - } else { - 0.0 - }; - - Ok(BpfAccessStats { - timestamp, - total_packets_processed, - total_packets_dropped, - ipv4_banned_hits, - ipv4_recently_banned_hits, - ipv6_banned_hits, - ipv6_recently_banned_hits, - tcp_fingerprint_blocks_ipv4, - tcp_fingerprint_blocks_ipv6, - drop_rate_percentage, - dropped_ip_addresses, - }) - } - - /// Read a counter value from a BPF array map - fn read_bpf_counter(map: &impl libbpf_rs::MapCore) -> Result> { - let key = 0u32.to_le_bytes(); - if let Some(value_bytes) = map.lookup(&key, libbpf_rs::MapFlags::ANY)? 
{ - if value_bytes.len() >= 8 { - let value = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], - ]); - Ok(value) - } else { - Ok(0) - } - } else { - Ok(0) - } - } - - /// Collect dropped IP addresses from BPF maps - fn collect_dropped_ip_addresses(skel: &FilterSkel) -> Result> { - let mut ipv4_addresses = HashMap::new(); - let mut ipv6_addresses = HashMap::new(); - - log::debug!("Collecting dropped IP addresses from BPF maps"); - - // Try batch lookup first, fall back to keys iterator if empty - log::debug!("Reading IPv4 dropped addresses from BPF map"); - let mut count = 0; - - // First try lookup_batch - let mut batch_worked = false; - if let Ok(batch_iter) = skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - for (key_bytes, value_bytes) in batch_iter { - batch_worked = true; - if key_bytes.len() >= 4 && value_bytes.len() >= 8 { - let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; - let ip_addr = Ipv4Addr::from(ip_bytes); - let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], - ]); - if drop_count > 0 { - log::debug!("Found dropped IPv4: {} (dropped {} times)", ip_addr, drop_count); - ipv4_addresses.insert(ip_addr.to_string(), drop_count); - count += 1; - } - } - } - } - - // If batch lookup returned nothing, try keys iterator as fallback - if !batch_worked { - log::debug!("Batch lookup empty, trying keys iterator for IPv4"); - for key_bytes in skel.maps.dropped_ipv4_addresses.keys() { - if key_bytes.len() >= 4 { - if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv4_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { - if value_bytes.len() >= 8 { - let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; - let ip_addr = Ipv4Addr::from(ip_bytes); - let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], - ]); - if drop_count > 0 { - log::debug!("Found dropped IPv4 (via keys): {} (dropped {} times)", ip_addr, drop_count); - ipv4_addresses.insert(ip_addr.to_string(), drop_count); - count += 1; - } - } - } - } - } - } - log::debug!("Found {} dropped IPv4 addresses", count); - - // Read IPv6 addresses - try batch lookup first, fall back to keys iterator - log::debug!("Reading IPv6 dropped addresses from BPF map"); - let mut ipv6_count = 0; - - let mut batch_worked = false; - if let Ok(batch_iter) = skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - for (key_bytes, value_bytes) in batch_iter { - batch_worked = true; - if key_bytes.len() >= 16 && value_bytes.len() >= 8 { - let mut ip_bytes = [0u8; 16]; - ip_bytes.copy_from_slice(&key_bytes[..16]); - let ip_addr = Ipv6Addr::from(ip_bytes); - let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], - ]); - if drop_count > 0 { - log::debug!("Found dropped IPv6: {} (dropped {} times)", ip_addr, drop_count); - ipv6_addresses.insert(ip_addr.to_string(), drop_count); - ipv6_count += 1; - } - } - } - } - - // If batch lookup returned nothing, try keys iterator - if !batch_worked { - log::debug!("Batch lookup empty, trying keys iterator for IPv6"); - for 
key_bytes in skel.maps.dropped_ipv6_addresses.keys() { - if key_bytes.len() >= 16 { - if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv6_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { - if value_bytes.len() >= 8 { - let mut ip_bytes = [0u8; 16]; - ip_bytes.copy_from_slice(&key_bytes[..16]); - let ip_addr = Ipv6Addr::from(ip_bytes); - let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], - ]); - if drop_count > 0 { - log::debug!("Found dropped IPv6 (via keys): {} (dropped {} times)", ip_addr, drop_count); - ipv6_addresses.insert(ip_addr.to_string(), drop_count); - ipv6_count += 1; - } - } - } - } - } - } - log::debug!("Found {} dropped IPv6 addresses", ipv6_count); - - let total_unique_dropped_ips = ipv4_addresses.len() as u64 + ipv6_addresses.len() as u64; - log::debug!("Total dropped IP addresses found: {} (IPv4: {}, IPv6: {})", - total_unique_dropped_ips, ipv4_addresses.len(), ipv6_addresses.len()); - - Ok(DroppedIpAddresses { - ipv4_addresses, - ipv6_addresses, - total_unique_dropped_ips, - }) - } - - - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - /// Create a summary string for logging - pub fn summary(&self) -> String { - let mut summary = format!( - "BPF Stats: {} packets processed, {} dropped ({:.2}%), IPv4 banned: {}, IPv4 recent: {}, IPv6 banned: {}, IPv6 recent: {}, TCP FP blocks (IPv4/IPv6): {}/{}", - self.total_packets_processed, - self.total_packets_dropped, - self.drop_rate_percentage, - self.ipv4_banned_hits, - self.ipv4_recently_banned_hits, - self.ipv6_banned_hits, - self.ipv6_recently_banned_hits, - self.tcp_fingerprint_blocks_ipv4, - self.tcp_fingerprint_blocks_ipv6 - ); - - // Add top dropped IP addresses if any - if !self.dropped_ip_addresses.ipv4_addresses.is_empty() || !self.dropped_ip_addresses.ipv6_addresses.is_empty() { - summary.push_str(&format!(", {} unique IPs dropped", self.dropped_ip_addresses.total_unique_dropped_ips)); - - // Show top 5 dropped IPv4 addresses - let mut ipv4_vec: Vec<_> = self.dropped_ip_addresses.ipv4_addresses.iter().collect(); - ipv4_vec.sort_by(|a, b| b.1.cmp(a.1)); - if !ipv4_vec.is_empty() { - summary.push_str(", Top IPv4 drops: "); - for (i, (ip, count)) in ipv4_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } - summary.push_str(&format!("{}:{}", ip, count)); - } - } - - // Show top 5 dropped IPv6 addresses - let mut ipv6_vec: Vec<_> = self.dropped_ip_addresses.ipv6_addresses.iter().collect(); - ipv6_vec.sort_by(|a, b| b.1.cmp(a.1)); - if !ipv6_vec.is_empty() { - summary.push_str(", Top IPv6 drops: "); - for (i, (ip, count)) in ipv6_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } - summary.push_str(&format!("{}:{}", ip, count)); - } - } - } - - summary - } -} - -impl DroppedIpEvent { - /// Create a new dropped IP event - pub fn new( - ip_address: String, - ip_version: IpVersion, - drop_count: u64, - drop_reason: DropReason, - ) -> Self { - let now = Utc::now(); - Self { - event_type: "dropped_ips".to_string(), - timestamp: now, - ip_address, - ip_version, - drop_count, - drop_reason - } - } - - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - /// Create a summary string for logging - pub fn summary(&self) -> String { - format!( - "IP Drop Event: {} {} dropped {} times (reason: {:?})", - self.ip_address, - match self.ip_version { - IpVersion::IPv4 => 
"IPv4", - IpVersion::IPv6 => "IPv6", - }, - self.drop_count, - self.drop_reason - ) - } -} - -impl DroppedIpEvents { - /// Create a new collection of dropped IP events - pub fn new() -> Self { - Self { - timestamp: Utc::now(), - events: Vec::new(), - total_events: 0, - unique_ips: 0, - } - } - - /// Add a dropped IP event - pub fn add_event(&mut self, event: DroppedIpEvent) { - self.events.push(event); - self.total_events += 1; - } - - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - /// Create a summary string for logging - pub fn summary(&self) -> String { - format!( - "Dropped IP Events: {} events from {} unique IPs", - self.total_events, - self.unique_ips - ) - } - - /// Get top dropped IPs by count - pub fn get_top_dropped_ips(&self, limit: usize) -> Vec { - let mut events = self.events.clone(); - events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count)); - events.into_iter().take(limit).collect() - } -} - -/// Statistics collector for BPF access rules -#[derive(Clone)] -pub struct BpfStatsCollector { - skels: Vec>>, - enabled: bool, -} - -impl BpfStatsCollector { - /// Create a new statistics collector - pub fn new(skels: Vec>>, enabled: bool) -> Self { - Self { skels, enabled } - } - - /// Enable or disable statistics collection - pub fn set_enabled(&mut self, enabled: bool) { - self.enabled = enabled; - } - - /// Check if statistics collection is enabled - pub fn is_enabled(&self) -> bool { - self.enabled - } - - /// Collect statistics from all BPF skeletons - pub fn collect_stats(&self) -> Result, Box> { - if !self.enabled { - return Ok(vec![]); - } - - let mut stats = Vec::new(); - for skel in &self.skels { - match BpfAccessStats::from_bpf_maps(skel) { - Ok(stat) => stats.push(stat), - Err(e) => { - log::warn!("Failed to collect BPF stats from skeleton: {}", e); - } - } - } - Ok(stats) - } - - /// Collect aggregated statistics across all skeletons - pub fn collect_aggregated_stats(&self) -> Result> { - if !self.enabled { - return Err("Statistics collection is disabled".into()); - } - - let individual_stats = self.collect_stats()?; - if individual_stats.is_empty() { - return Err("No statistics available".into()); - } - - // Aggregate statistics across all skeletons - let mut aggregated = BpfAccessStats { - timestamp: Utc::now(), - total_packets_processed: 0, - total_packets_dropped: 0, - ipv4_banned_hits: 0, - ipv4_recently_banned_hits: 0, - ipv6_banned_hits: 0, - ipv6_recently_banned_hits: 0, - tcp_fingerprint_blocks_ipv4: 0, - tcp_fingerprint_blocks_ipv6: 0, - drop_rate_percentage: 0.0, - dropped_ip_addresses: DroppedIpAddresses { - ipv4_addresses: HashMap::new(), - ipv6_addresses: HashMap::new(), - total_unique_dropped_ips: 0, - }, - }; - - for stat in individual_stats { - aggregated.total_packets_processed += stat.total_packets_processed; - aggregated.total_packets_dropped += stat.total_packets_dropped; - aggregated.ipv4_banned_hits += stat.ipv4_banned_hits; - aggregated.ipv4_recently_banned_hits += stat.ipv4_recently_banned_hits; - aggregated.ipv6_banned_hits += stat.ipv6_banned_hits; - aggregated.ipv6_recently_banned_hits += stat.ipv6_recently_banned_hits; - aggregated.tcp_fingerprint_blocks_ipv4 += stat.tcp_fingerprint_blocks_ipv4; - aggregated.tcp_fingerprint_blocks_ipv6 += stat.tcp_fingerprint_blocks_ipv6; - - // Merge IP addresses - for (ip, count) in stat.dropped_ip_addresses.ipv4_addresses { - *aggregated.dropped_ip_addresses.ipv4_addresses.entry(ip).or_insert(0) += count; - } - for (ip, count) in 
stat.dropped_ip_addresses.ipv6_addresses { - *aggregated.dropped_ip_addresses.ipv6_addresses.entry(ip).or_insert(0) += count; - } - } - - // Update total unique dropped IPs count - aggregated.dropped_ip_addresses.total_unique_dropped_ips = - aggregated.dropped_ip_addresses.ipv4_addresses.len() as u64 + - aggregated.dropped_ip_addresses.ipv6_addresses.len() as u64; - - // Recalculate drop rate for aggregated data - aggregated.drop_rate_percentage = if aggregated.total_packets_processed > 0 { - (aggregated.total_packets_dropped as f64 / aggregated.total_packets_processed as f64) * 100.0 - } else { - 0.0 - }; - - Ok(aggregated) - } - - /// Log current statistics - pub fn log_stats(&self) -> Result<(), Box> { - if !self.enabled { - return Ok(()); - } - - match self.collect_aggregated_stats() { - Ok(stats) => { - // Output as JSON for structured logging - match stats.to_json() { - Ok(json) => { - log::info!("{}", json); - } - Err(e) => { - // Fallback to text summary if JSON serialization fails - log::warn!("Failed to serialize BPF stats to JSON: {}, using text summary", e); - log::info!("{}", stats.summary()); - } - } - Ok(()) - } - Err(e) => { - log::warn!("Failed to collect BPF statistics: {}", e); - Err(e) - } - } - } - - /// Collect dropped IP events from BPF maps - pub fn collect_dropped_ip_events(&self) -> Result> { - if !self.enabled { - return Ok(DroppedIpEvents::new()); - } - - let mut events = DroppedIpEvents::new(); - - for skel in &self.skels { - let dropped_ips = BpfAccessStats::collect_dropped_ip_addresses(skel)?; - - // Convert IPv4 addresses to events - for (ip_str, count) in dropped_ips.ipv4_addresses { - let event = DroppedIpEvent::new( - ip_str, - IpVersion::IPv4, - count, - DropReason::AccessRules, // Default reason, could be enhanced - ); - events.add_event(event); - } - - // Convert IPv6 addresses to events - for (ip_str, count) in dropped_ips.ipv6_addresses { - let event = DroppedIpEvent::new( - ip_str, - IpVersion::IPv6, - count, - DropReason::AccessRules, // Default reason, could be enhanced - ); - events.add_event(event); - } - } - - events.unique_ips = events.events.len() as u64; - Ok(events) - } - - /// Log dropped IP events - pub fn log_dropped_ip_events(&self) -> Result<(), Box> { - if !self.enabled { - return Ok(()); - } - - let events = self.collect_dropped_ip_events()?; - - if events.total_events > 0 { - log::debug!("{}", events.summary()); - - // Log top 5 dropped IPs - let top_ips = events.get_top_dropped_ips(5); - for event in top_ips { - log::debug!(" {}", event.summary()); - } - - // Log as JSON for structured logging - if let Ok(json) = events.to_json() { - log::debug!("Dropped IP Events JSON: {}", json); - } - - // Send events to unified queue - for event in events.events { - send_event(UnifiedEvent::DroppedIp(event)); - } - - // Reset the counters after logging - self.reset_dropped_ip_counters()?; - } else { - log::debug!("No dropped IP events found"); - } - - Ok(()) - } - - - /// Reset dropped IP address counters in BPF maps - pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { - if !self.enabled { - return Ok(()); - } - - log::debug!("Resetting dropped IP address counters"); - - for skel in &self.skels { - // Reset IPv4 counters - match skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - Ok(batch_iter) => { - let mut reset_count = 0; - for (key_bytes, _) in batch_iter { - if key_bytes.len() >= 4 { - let zero_count = 0u64.to_le_bytes(); - if let Err(e) = 
skel.maps.dropped_ipv4_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { - log::warn!("Failed to reset IPv4 counter: {}", e); - } else { - reset_count += 1; - } - } - } - log::debug!("Reset {} IPv4 dropped IP counters", reset_count); - } - Err(e) => { - log::warn!("Failed to reset IPv4 counters: {}", e); - } - } - - // Reset IPv6 counters - match skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - Ok(batch_iter) => { - let mut reset_count = 0; - for (key_bytes, _) in batch_iter { - if key_bytes.len() >= 16 { - let zero_count = 0u64.to_le_bytes(); - if let Err(e) = skel.maps.dropped_ipv6_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { - log::warn!("Failed to reset IPv6 counter: {}", e); - } else { - reset_count += 1; - } - } - } - log::debug!("Reset {} IPv6 dropped IP counters", reset_count); - } - Err(e) => { - log::warn!("Failed to reset IPv6 counters: {}", e); - } - } - } - - Ok(()) - } -} - -/// Configuration for BPF statistics collection -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BpfStatsConfig { - pub enabled: bool, - pub log_interval_secs: u64, -} - -impl Default for BpfStatsConfig { - fn default() -> Self { - Self { - enabled: true, - log_interval_secs: 60, // Log stats every minute - } - } -} - -impl BpfStatsConfig { - /// Create a new configuration - pub fn new(enabled: bool, log_interval_secs: u64) -> Self { - Self { - enabled, - log_interval_secs, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_bpf_stats_summary() { - let mut ipv4_addresses = HashMap::new(); - ipv4_addresses.insert("192.168.1.1".to_string(), 10); - ipv4_addresses.insert("10.0.0.1".to_string(), 5); - - let stats = BpfAccessStats { - timestamp: Utc::now(), - total_packets_processed: 1000, - total_packets_dropped: 50, - ipv4_banned_hits: 30, - ipv4_recently_banned_hits: 10, - ipv6_banned_hits: 5, - ipv6_recently_banned_hits: 5, - tcp_fingerprint_blocks_ipv4: 0, - tcp_fingerprint_blocks_ipv6: 0, - drop_rate_percentage: 5.0, - dropped_ip_addresses: DroppedIpAddresses { - ipv4_addresses, - ipv6_addresses: HashMap::new(), - total_unique_dropped_ips: 2, - }, - }; - - let summary = stats.summary(); - assert!(summary.contains("1000 packets processed")); - assert!(summary.contains("50 dropped")); - assert!(summary.contains("5.00%")); - assert!(summary.contains("2 unique IPs dropped")); - assert!(summary.contains("192.168.1.1:10")); - } - - #[test] - fn test_bpf_stats_json() { - let stats = BpfAccessStats { - timestamp: Utc::now(), - total_packets_processed: 100, - total_packets_dropped: 10, - ipv4_banned_hits: 5, - ipv4_recently_banned_hits: 3, - ipv6_banned_hits: 1, - ipv6_recently_banned_hits: 1, - tcp_fingerprint_blocks_ipv4: 0, - tcp_fingerprint_blocks_ipv6: 0, - drop_rate_percentage: 10.0, - dropped_ip_addresses: DroppedIpAddresses { - ipv4_addresses: HashMap::new(), - ipv6_addresses: HashMap::new(), - total_unique_dropped_ips: 0, - }, - }; - - let json = stats.to_json().unwrap(); - assert!(json.contains("total_packets_processed")); - assert!(json.contains("drop_rate_percentage")); - assert!(json.contains("dropped_ip_addresses")); - } -} +use std::sync::Arc; +use serde::{Deserialize, Serialize}; +use chrono::{DateTime, Utc}; +use std::collections::HashMap; +use std::net::{Ipv4Addr, Ipv6Addr}; +use libbpf_rs::MapCore; +use crate::worker::log::{send_event, UnifiedEvent}; + +use crate::bpf::FilterSkel; + +/// BPF statistics collected from kernel-level access rule enforcement 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BpfAccessStats { + pub timestamp: DateTime, + pub total_packets_processed: u64, + pub total_packets_dropped: u64, + pub ipv4_banned_hits: u64, + pub ipv4_recently_banned_hits: u64, + pub ipv6_banned_hits: u64, + pub ipv6_recently_banned_hits: u64, + pub tcp_fingerprint_blocks_ipv4: u64, + pub tcp_fingerprint_blocks_ipv6: u64, + pub drop_rate_percentage: f64, + pub dropped_ip_addresses: DroppedIpAddresses, +} + +/// Statistics about dropped IP addresses +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpAddresses { + pub ipv4_addresses: HashMap, // IP address -> drop count + pub ipv6_addresses: HashMap, // IP address -> drop count + pub total_unique_dropped_ips: u64, +} + +/// Individual event for a dropped IP address +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpEvent { + pub event_type: String, + pub timestamp: DateTime, + pub ip_address: String, + pub ip_version: IpVersion, + pub drop_count: u64, + pub drop_reason: DropReason +} + +/// IP version enumeration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum IpVersion { + IPv4, + IPv6, +} + +/// Reason for dropping packets +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DropReason { + AccessRules, + RecentlyBannedUdp, + RecentlyBannedIcmp, + RecentlyBannedTcpFinRst, +} + +/// Collection of dropped IP events +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpEvents { + pub timestamp: DateTime, + pub events: Vec, + pub total_events: u64, + pub unique_ips: u64, +} + +impl BpfAccessStats { + /// Create a new statistics snapshot from BPF maps + pub fn from_bpf_maps(skel: &FilterSkel) -> Result> { + let timestamp = Utc::now(); + + // Read statistics from BPF maps + let total_packets_processed = Self::read_bpf_counter(&skel.maps.total_packets_processed)?; + let total_packets_dropped = Self::read_bpf_counter(&skel.maps.total_packets_dropped)?; + let ipv4_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_banned_stats)?; + let ipv4_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_recently_banned_stats)?; + let ipv6_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_banned_stats)?; + let ipv6_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_recently_banned_stats)?; + let tcp_fingerprint_blocks_ipv4 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv4)?; + let tcp_fingerprint_blocks_ipv6 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv6)?; + + // Collect dropped IP addresses + let dropped_ip_addresses = Self::collect_dropped_ip_addresses(skel)?; + + // Calculate drop rate percentage + let drop_rate_percentage = if total_packets_processed > 0 { + (total_packets_dropped as f64 / total_packets_processed as f64) * 100.0 + } else { + 0.0 + }; + + Ok(BpfAccessStats { + timestamp, + total_packets_processed, + total_packets_dropped, + ipv4_banned_hits, + ipv4_recently_banned_hits, + ipv6_banned_hits, + ipv6_recently_banned_hits, + tcp_fingerprint_blocks_ipv4, + tcp_fingerprint_blocks_ipv6, + drop_rate_percentage, + dropped_ip_addresses, + }) + } + + /// Read a counter value from a BPF array map + fn read_bpf_counter(map: &impl libbpf_rs::MapCore) -> Result> { + let key = 0u32.to_le_bytes(); + if let Some(value_bytes) = map.lookup(&key, libbpf_rs::MapFlags::ANY)? 
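+        // The counter lives at array index 0 as a single u64; the bytes are
+        // decoded as little-endian below, and short or missing values read as zero.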
{ + if value_bytes.len() >= 8 { + let value = u64::from_le_bytes([ + value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], + value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + ]); + Ok(value) + } else { + Ok(0) + } + } else { + Ok(0) + } + } + + /// Collect dropped IP addresses from BPF maps + fn collect_dropped_ip_addresses(skel: &FilterSkel) -> Result> { + let mut ipv4_addresses = HashMap::new(); + let mut ipv6_addresses = HashMap::new(); + + log::debug!("Collecting dropped IP addresses from BPF maps"); + + // Try batch lookup first, fall back to keys iterator if empty + log::debug!("Reading IPv4 dropped addresses from BPF map"); + let mut count = 0; + + // First try lookup_batch + let mut batch_worked = false; + if let Ok(batch_iter) = skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + for (key_bytes, value_bytes) in batch_iter { + batch_worked = true; + if key_bytes.len() >= 4 && value_bytes.len() >= 8 { + let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; + let ip_addr = Ipv4Addr::from(ip_bytes); + let drop_count = u64::from_le_bytes([ + value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], + value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + ]); + if drop_count > 0 { + log::debug!("Found dropped IPv4: {} (dropped {} times)", ip_addr, drop_count); + ipv4_addresses.insert(ip_addr.to_string(), drop_count); + count += 1; + } + } + } + } + + // If batch lookup returned nothing, try keys iterator as fallback + if !batch_worked { + log::debug!("Batch lookup empty, trying keys iterator for IPv4"); + for key_bytes in skel.maps.dropped_ipv4_addresses.keys() { + if key_bytes.len() >= 4 { + if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv4_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { + if value_bytes.len() >= 8 { + let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; + let ip_addr = Ipv4Addr::from(ip_bytes); + let drop_count = u64::from_le_bytes([ + value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], + value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + ]); + if drop_count > 0 { + log::debug!("Found dropped IPv4 (via keys): {} (dropped {} times)", ip_addr, drop_count); + ipv4_addresses.insert(ip_addr.to_string(), drop_count); + count += 1; + } + } + } + } + } + } + log::debug!("Found {} dropped IPv4 addresses", count); + + // Read IPv6 addresses - try batch lookup first, fall back to keys iterator + log::debug!("Reading IPv6 dropped addresses from BPF map"); + let mut ipv6_count = 0; + + let mut batch_worked = false; + if let Ok(batch_iter) = skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + for (key_bytes, value_bytes) in batch_iter { + batch_worked = true; + if key_bytes.len() >= 16 && value_bytes.len() >= 8 { + let mut ip_bytes = [0u8; 16]; + ip_bytes.copy_from_slice(&key_bytes[..16]); + let ip_addr = Ipv6Addr::from(ip_bytes); + let drop_count = u64::from_le_bytes([ + value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], + value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + ]); + if drop_count > 0 { + log::debug!("Found dropped IPv6: {} (dropped {} times)", ip_addr, drop_count); + ipv6_addresses.insert(ip_addr.to_string(), drop_count); + ipv6_count += 1; + } + } + } + } + + // If batch lookup returned nothing, try keys iterator + if !batch_worked { + log::debug!("Batch lookup empty, trying keys iterator for IPv6"); + for 
key_bytes in skel.maps.dropped_ipv6_addresses.keys() { + if key_bytes.len() >= 16 { + if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv6_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { + if value_bytes.len() >= 8 { + let mut ip_bytes = [0u8; 16]; + ip_bytes.copy_from_slice(&key_bytes[..16]); + let ip_addr = Ipv6Addr::from(ip_bytes); + let drop_count = u64::from_le_bytes([ + value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], + value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + ]); + if drop_count > 0 { + log::debug!("Found dropped IPv6 (via keys): {} (dropped {} times)", ip_addr, drop_count); + ipv6_addresses.insert(ip_addr.to_string(), drop_count); + ipv6_count += 1; + } + } + } + } + } + } + log::debug!("Found {} dropped IPv6 addresses", ipv6_count); + + let total_unique_dropped_ips = ipv4_addresses.len() as u64 + ipv6_addresses.len() as u64; + log::debug!("Total dropped IP addresses found: {} (IPv4: {}, IPv6: {})", + total_unique_dropped_ips, ipv4_addresses.len(), ipv6_addresses.len()); + + Ok(DroppedIpAddresses { + ipv4_addresses, + ipv6_addresses, + total_unique_dropped_ips, + }) + } + + + /// Convert to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + /// Create a summary string for logging + pub fn summary(&self) -> String { + let mut summary = format!( + "BPF Stats: {} packets processed, {} dropped ({:.2}%), IPv4 banned: {}, IPv4 recent: {}, IPv6 banned: {}, IPv6 recent: {}, TCP FP blocks (IPv4/IPv6): {}/{}", + self.total_packets_processed, + self.total_packets_dropped, + self.drop_rate_percentage, + self.ipv4_banned_hits, + self.ipv4_recently_banned_hits, + self.ipv6_banned_hits, + self.ipv6_recently_banned_hits, + self.tcp_fingerprint_blocks_ipv4, + self.tcp_fingerprint_blocks_ipv6 + ); + + // Add top dropped IP addresses if any + if !self.dropped_ip_addresses.ipv4_addresses.is_empty() || !self.dropped_ip_addresses.ipv6_addresses.is_empty() { + summary.push_str(&format!(", {} unique IPs dropped", self.dropped_ip_addresses.total_unique_dropped_ips)); + + // Show top 5 dropped IPv4 addresses + let mut ipv4_vec: Vec<_> = self.dropped_ip_addresses.ipv4_addresses.iter().collect(); + ipv4_vec.sort_by(|a, b| b.1.cmp(a.1)); + if !ipv4_vec.is_empty() { + summary.push_str(", Top IPv4 drops: "); + for (i, (ip, count)) in ipv4_vec.iter().take(5).enumerate() { + if i > 0 { summary.push_str(", "); } + summary.push_str(&format!("{}:{}", ip, count)); + } + } + + // Show top 5 dropped IPv6 addresses + let mut ipv6_vec: Vec<_> = self.dropped_ip_addresses.ipv6_addresses.iter().collect(); + ipv6_vec.sort_by(|a, b| b.1.cmp(a.1)); + if !ipv6_vec.is_empty() { + summary.push_str(", Top IPv6 drops: "); + for (i, (ip, count)) in ipv6_vec.iter().take(5).enumerate() { + if i > 0 { summary.push_str(", "); } + summary.push_str(&format!("{}:{}", ip, count)); + } + } + } + + summary + } +} + +impl DroppedIpEvent { + /// Create a new dropped IP event + pub fn new( + ip_address: String, + ip_version: IpVersion, + drop_count: u64, + drop_reason: DropReason, + ) -> Self { + let now = Utc::now(); + Self { + event_type: "dropped_ips".to_string(), + timestamp: now, + ip_address, + ip_version, + drop_count, + drop_reason + } + } + + /// Convert to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + /// Create a summary string for logging + pub fn summary(&self) -> String { + format!( + "IP Drop Event: {} {} dropped {} times (reason: {:?})", + self.ip_address, + match self.ip_version { + IpVersion::IPv4 => 
"IPv4", + IpVersion::IPv6 => "IPv6", + }, + self.drop_count, + self.drop_reason + ) + } +} + +impl DroppedIpEvents { + /// Create a new collection of dropped IP events + pub fn new() -> Self { + Self { + timestamp: Utc::now(), + events: Vec::new(), + total_events: 0, + unique_ips: 0, + } + } + + /// Add a dropped IP event + pub fn add_event(&mut self, event: DroppedIpEvent) { + self.events.push(event); + self.total_events += 1; + } + + /// Convert to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + /// Create a summary string for logging + pub fn summary(&self) -> String { + format!( + "Dropped IP Events: {} events from {} unique IPs", + self.total_events, + self.unique_ips + ) + } + + /// Get top dropped IPs by count + pub fn get_top_dropped_ips(&self, limit: usize) -> Vec { + let mut events = self.events.clone(); + events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count)); + events.into_iter().take(limit).collect() + } +} + +/// Statistics collector for BPF access rules +#[derive(Clone)] +pub struct BpfStatsCollector { + skels: Vec>>, + enabled: bool, +} + +impl BpfStatsCollector { + /// Create a new statistics collector + pub fn new(skels: Vec>>, enabled: bool) -> Self { + Self { skels, enabled } + } + + /// Enable or disable statistics collection + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + /// Check if statistics collection is enabled + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Collect statistics from all BPF skeletons + pub fn collect_stats(&self) -> Result, Box> { + if !self.enabled { + return Ok(vec![]); + } + + let mut stats = Vec::new(); + for skel in &self.skels { + match BpfAccessStats::from_bpf_maps(skel) { + Ok(stat) => stats.push(stat), + Err(e) => { + log::warn!("Failed to collect BPF stats from skeleton: {}", e); + } + } + } + Ok(stats) + } + + /// Collect aggregated statistics across all skeletons + pub fn collect_aggregated_stats(&self) -> Result> { + if !self.enabled { + return Err("Statistics collection is disabled".into()); + } + + let individual_stats = self.collect_stats()?; + if individual_stats.is_empty() { + return Err("No statistics available".into()); + } + + // Aggregate statistics across all skeletons + let mut aggregated = BpfAccessStats { + timestamp: Utc::now(), + total_packets_processed: 0, + total_packets_dropped: 0, + ipv4_banned_hits: 0, + ipv4_recently_banned_hits: 0, + ipv6_banned_hits: 0, + ipv6_recently_banned_hits: 0, + tcp_fingerprint_blocks_ipv4: 0, + tcp_fingerprint_blocks_ipv6: 0, + drop_rate_percentage: 0.0, + dropped_ip_addresses: DroppedIpAddresses { + ipv4_addresses: HashMap::new(), + ipv6_addresses: HashMap::new(), + total_unique_dropped_ips: 0, + }, + }; + + for stat in individual_stats { + aggregated.total_packets_processed += stat.total_packets_processed; + aggregated.total_packets_dropped += stat.total_packets_dropped; + aggregated.ipv4_banned_hits += stat.ipv4_banned_hits; + aggregated.ipv4_recently_banned_hits += stat.ipv4_recently_banned_hits; + aggregated.ipv6_banned_hits += stat.ipv6_banned_hits; + aggregated.ipv6_recently_banned_hits += stat.ipv6_recently_banned_hits; + aggregated.tcp_fingerprint_blocks_ipv4 += stat.tcp_fingerprint_blocks_ipv4; + aggregated.tcp_fingerprint_blocks_ipv6 += stat.tcp_fingerprint_blocks_ipv6; + + // Merge IP addresses + for (ip, count) in stat.dropped_ip_addresses.ipv4_addresses { + *aggregated.dropped_ip_addresses.ipv4_addresses.entry(ip).or_insert(0) += count; + } + for (ip, count) in 
stat.dropped_ip_addresses.ipv6_addresses { + *aggregated.dropped_ip_addresses.ipv6_addresses.entry(ip).or_insert(0) += count; + } + } + + // Update total unique dropped IPs count + aggregated.dropped_ip_addresses.total_unique_dropped_ips = + aggregated.dropped_ip_addresses.ipv4_addresses.len() as u64 + + aggregated.dropped_ip_addresses.ipv6_addresses.len() as u64; + + // Recalculate drop rate for aggregated data + aggregated.drop_rate_percentage = if aggregated.total_packets_processed > 0 { + (aggregated.total_packets_dropped as f64 / aggregated.total_packets_processed as f64) * 100.0 + } else { + 0.0 + }; + + Ok(aggregated) + } + + /// Log current statistics + pub fn log_stats(&self) -> Result<(), Box> { + if !self.enabled { + return Ok(()); + } + + match self.collect_aggregated_stats() { + Ok(stats) => { + // Output as JSON for structured logging + match stats.to_json() { + Ok(json) => { + log::info!("{}", json); + } + Err(e) => { + // Fallback to text summary if JSON serialization fails + log::warn!("Failed to serialize BPF stats to JSON: {}, using text summary", e); + log::info!("{}", stats.summary()); + } + } + Ok(()) + } + Err(e) => { + log::warn!("Failed to collect BPF statistics: {}", e); + Err(e) + } + } + } + + /// Collect dropped IP events from BPF maps + pub fn collect_dropped_ip_events(&self) -> Result> { + if !self.enabled { + return Ok(DroppedIpEvents::new()); + } + + let mut events = DroppedIpEvents::new(); + + for skel in &self.skels { + let dropped_ips = BpfAccessStats::collect_dropped_ip_addresses(skel)?; + + // Convert IPv4 addresses to events + for (ip_str, count) in dropped_ips.ipv4_addresses { + let event = DroppedIpEvent::new( + ip_str, + IpVersion::IPv4, + count, + DropReason::AccessRules, // Default reason, could be enhanced + ); + events.add_event(event); + } + + // Convert IPv6 addresses to events + for (ip_str, count) in dropped_ips.ipv6_addresses { + let event = DroppedIpEvent::new( + ip_str, + IpVersion::IPv6, + count, + DropReason::AccessRules, // Default reason, could be enhanced + ); + events.add_event(event); + } + } + + events.unique_ips = events.events.len() as u64; + Ok(events) + } + + /// Log dropped IP events + pub fn log_dropped_ip_events(&self) -> Result<(), Box> { + if !self.enabled { + return Ok(()); + } + + let events = self.collect_dropped_ip_events()?; + + if events.total_events > 0 { + log::debug!("{}", events.summary()); + + // Log top 5 dropped IPs + let top_ips = events.get_top_dropped_ips(5); + for event in top_ips { + log::debug!(" {}", event.summary()); + } + + // Log as JSON for structured logging + if let Ok(json) = events.to_json() { + log::debug!("Dropped IP Events JSON: {}", json); + } + + // Send events to unified queue + for event in events.events { + send_event(UnifiedEvent::DroppedIp(event)); + } + + // Reset the counters after logging + self.reset_dropped_ip_counters()?; + } else { + log::debug!("No dropped IP events found"); + } + + Ok(()) + } + + + /// Reset dropped IP address counters in BPF maps + pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { + if !self.enabled { + return Ok(()); + } + + log::debug!("Resetting dropped IP address counters"); + + for skel in &self.skels { + // Reset IPv4 counters + match skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + Ok(batch_iter) => { + let mut reset_count = 0; + for (key_bytes, _) in batch_iter { + if key_bytes.len() >= 4 { + let zero_count = 0u64.to_le_bytes(); + if let Err(e) = 
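+                            // Overwrite the value with zero rather than deleting the
+                            // key, which keeps the key set stable while the batch of
+                            // entries is being walked.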
skel.maps.dropped_ipv4_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { + log::warn!("Failed to reset IPv4 counter: {}", e); + } else { + reset_count += 1; + } + } + } + log::debug!("Reset {} IPv4 dropped IP counters", reset_count); + } + Err(e) => { + log::warn!("Failed to reset IPv4 counters: {}", e); + } + } + + // Reset IPv6 counters + match skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + Ok(batch_iter) => { + let mut reset_count = 0; + for (key_bytes, _) in batch_iter { + if key_bytes.len() >= 16 { + let zero_count = 0u64.to_le_bytes(); + if let Err(e) = skel.maps.dropped_ipv6_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { + log::warn!("Failed to reset IPv6 counter: {}", e); + } else { + reset_count += 1; + } + } + } + log::debug!("Reset {} IPv6 dropped IP counters", reset_count); + } + Err(e) => { + log::warn!("Failed to reset IPv6 counters: {}", e); + } + } + } + + Ok(()) + } +} + +/// Configuration for BPF statistics collection +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BpfStatsConfig { + pub enabled: bool, + pub log_interval_secs: u64, +} + +impl Default for BpfStatsConfig { + fn default() -> Self { + Self { + enabled: true, + log_interval_secs: 60, // Log stats every minute + } + } +} + +impl BpfStatsConfig { + /// Create a new configuration + pub fn new(enabled: bool, log_interval_secs: u64) -> Self { + Self { + enabled, + log_interval_secs, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bpf_stats_summary() { + let mut ipv4_addresses = HashMap::new(); + ipv4_addresses.insert("192.168.1.1".to_string(), 10); + ipv4_addresses.insert("10.0.0.1".to_string(), 5); + + let stats = BpfAccessStats { + timestamp: Utc::now(), + total_packets_processed: 1000, + total_packets_dropped: 50, + ipv4_banned_hits: 30, + ipv4_recently_banned_hits: 10, + ipv6_banned_hits: 5, + ipv6_recently_banned_hits: 5, + tcp_fingerprint_blocks_ipv4: 0, + tcp_fingerprint_blocks_ipv6: 0, + drop_rate_percentage: 5.0, + dropped_ip_addresses: DroppedIpAddresses { + ipv4_addresses, + ipv6_addresses: HashMap::new(), + total_unique_dropped_ips: 2, + }, + }; + + let summary = stats.summary(); + assert!(summary.contains("1000 packets processed")); + assert!(summary.contains("50 dropped")); + assert!(summary.contains("5.00%")); + assert!(summary.contains("2 unique IPs dropped")); + assert!(summary.contains("192.168.1.1:10")); + } + + #[test] + fn test_bpf_stats_json() { + let stats = BpfAccessStats { + timestamp: Utc::now(), + total_packets_processed: 100, + total_packets_dropped: 10, + ipv4_banned_hits: 5, + ipv4_recently_banned_hits: 3, + ipv6_banned_hits: 1, + ipv6_recently_banned_hits: 1, + tcp_fingerprint_blocks_ipv4: 0, + tcp_fingerprint_blocks_ipv6: 0, + drop_rate_percentage: 10.0, + dropped_ip_addresses: DroppedIpAddresses { + ipv4_addresses: HashMap::new(), + ipv6_addresses: HashMap::new(), + total_unique_dropped_ips: 0, + }, + }; + + let json = stats.to_json().unwrap(); + assert!(json.contains("total_packets_processed")); + assert!(json.contains("drop_rate_percentage")); + assert!(json.contains("dropped_ip_addresses")); + } +} diff --git a/src/bpf_stats_noop.rs b/src/bpf_stats_noop.rs index db34e09..d483958 100644 --- a/src/bpf_stats_noop.rs +++ b/src/bpf_stats_noop.rs @@ -1,236 +1,236 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; - -use crate::bpf::FilterSkel; - -/// BPF statistics 
collected from kernel-level access rule enforcement -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BpfAccessStats { - pub timestamp: DateTime, - pub total_packets_processed: u64, - pub total_packets_dropped: u64, - pub ipv4_banned_hits: u64, - pub ipv4_recently_banned_hits: u64, - pub ipv6_banned_hits: u64, - pub ipv6_recently_banned_hits: u64, - pub tcp_fingerprint_blocks_ipv4: u64, - pub tcp_fingerprint_blocks_ipv6: u64, - pub drop_rate_percentage: f64, - pub dropped_ip_addresses: DroppedIpAddresses, -} - -/// Statistics about dropped IP addresses -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpAddresses { - pub ipv4_addresses: HashMap, - pub ipv6_addresses: HashMap, - pub total_unique_dropped_ips: u64, -} - -/// Individual event for a dropped IP address -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvent { - pub event_type: String, - pub timestamp: DateTime, - pub ip_address: String, - pub ip_version: IpVersion, - pub drop_count: u64, - pub drop_reason: DropReason, -} - -/// IP version enumeration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum IpVersion { - IPv4, - IPv6, -} - -/// Reason for dropping packets -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum DropReason { - AccessRules, - RecentlyBannedUdp, - RecentlyBannedIcmp, - RecentlyBannedTcpFinRst, -} - -/// Collection of dropped IP events -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvents { - pub timestamp: DateTime, - pub events: Vec, - pub total_events: u64, - pub unique_ips: u64, -} - -impl BpfAccessStats { - pub fn empty() -> Self { - Self { - timestamp: Utc::now(), - total_packets_processed: 0, - total_packets_dropped: 0, - ipv4_banned_hits: 0, - ipv4_recently_banned_hits: 0, - ipv6_banned_hits: 0, - ipv6_recently_banned_hits: 0, - tcp_fingerprint_blocks_ipv4: 0, - tcp_fingerprint_blocks_ipv6: 0, - drop_rate_percentage: 0.0, - dropped_ip_addresses: DroppedIpAddresses { - ipv4_addresses: HashMap::new(), - ipv6_addresses: HashMap::new(), - total_unique_dropped_ips: 0, - }, - } - } - - /// Create a summary string for logging - pub fn summary(&self) -> String { - format!( - "BPF Stats: {} packets processed, {} dropped ({:.2}%), {} unique IPs dropped", - self.total_packets_processed, - self.total_packets_dropped, - self.drop_rate_percentage, - self.dropped_ip_addresses.total_unique_dropped_ips - ) - } -} - -impl DroppedIpEvent { - pub fn new( - ip_address: String, - ip_version: IpVersion, - drop_count: u64, - drop_reason: DropReason, - ) -> Self { - Self { - event_type: "dropped_ips".to_string(), - timestamp: Utc::now(), - ip_address, - ip_version, - drop_count, - drop_reason, - } - } - - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - pub fn summary(&self) -> String { - format!( - "IP Drop Event: {} {:?} dropped {} times (reason: {:?})", - self.ip_address, - self.ip_version, - self.drop_count, - self.drop_reason - ) - } -} - -impl DroppedIpEvents { - pub fn new() -> Self { - Self { - timestamp: Utc::now(), - events: Vec::new(), - total_events: 0, - unique_ips: 0, - } - } - - pub fn add_event(&mut self, event: DroppedIpEvent) { - self.events.push(event); - self.total_events += 1; - } - - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - pub fn summary(&self) -> String { - format!( - "Dropped IP Events: {} events from {} unique IPs", - self.total_events, - self.unique_ips - ) - } - - pub fn get_top_dropped_ips(&self, limit: usize) -> Vec { - let mut events = self.events.clone(); - 
events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count)); - events.into_iter().take(limit).collect() - } -} - -/// Statistics collector for BPF access rules -#[derive(Clone)] -pub struct BpfStatsCollector { - _skels: Vec>>, - enabled: bool, -} - -impl BpfStatsCollector { - pub fn new(skels: Vec>>, enabled: bool) -> Self { - Self { _skels: skels, enabled } - } - - pub fn set_enabled(&mut self, enabled: bool) { - self.enabled = enabled; - } - - pub fn is_enabled(&self) -> bool { - self.enabled - } - - pub fn collect_stats(&self) -> Result, Box> { - Ok(Vec::new()) - } - - pub fn collect_aggregated_stats(&self) -> Result> { - Ok(BpfAccessStats::empty()) - } - - pub fn log_stats(&self) -> Result<(), Box> { - Ok(()) - } - - pub fn collect_dropped_ip_events(&self) -> Result> { - Ok(DroppedIpEvents::new()) - } - - pub fn log_dropped_ip_events(&self) -> Result<(), Box> { - Ok(()) - } - - pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { - Ok(()) - } -} - -/// Configuration for BPF statistics collection -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BpfStatsConfig { - pub enabled: bool, - pub log_interval_secs: u64, -} - -impl Default for BpfStatsConfig { - fn default() -> Self { - Self { - enabled: true, - log_interval_secs: 60, - } - } -} - -impl BpfStatsConfig { - pub fn new(enabled: bool, log_interval_secs: u64) -> Self { - Self { - enabled, - log_interval_secs, - } - } -} +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::bpf::FilterSkel; + +/// BPF statistics collected from kernel-level access rule enforcement +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BpfAccessStats { + pub timestamp: DateTime, + pub total_packets_processed: u64, + pub total_packets_dropped: u64, + pub ipv4_banned_hits: u64, + pub ipv4_recently_banned_hits: u64, + pub ipv6_banned_hits: u64, + pub ipv6_recently_banned_hits: u64, + pub tcp_fingerprint_blocks_ipv4: u64, + pub tcp_fingerprint_blocks_ipv6: u64, + pub drop_rate_percentage: f64, + pub dropped_ip_addresses: DroppedIpAddresses, +} + +/// Statistics about dropped IP addresses +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpAddresses { + pub ipv4_addresses: HashMap, + pub ipv6_addresses: HashMap, + pub total_unique_dropped_ips: u64, +} + +/// Individual event for a dropped IP address +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpEvent { + pub event_type: String, + pub timestamp: DateTime, + pub ip_address: String, + pub ip_version: IpVersion, + pub drop_count: u64, + pub drop_reason: DropReason, +} + +/// IP version enumeration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum IpVersion { + IPv4, + IPv6, +} + +/// Reason for dropping packets +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DropReason { + AccessRules, + RecentlyBannedUdp, + RecentlyBannedIcmp, + RecentlyBannedTcpFinRst, +} + +/// Collection of dropped IP events +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DroppedIpEvents { + pub timestamp: DateTime, + pub events: Vec, + pub total_events: u64, + pub unique_ips: u64, +} + +impl BpfAccessStats { + pub fn empty() -> Self { + Self { + timestamp: Utc::now(), + total_packets_processed: 0, + total_packets_dropped: 0, + ipv4_banned_hits: 0, + ipv4_recently_banned_hits: 0, + ipv6_banned_hits: 0, + ipv6_recently_banned_hits: 0, + tcp_fingerprint_blocks_ipv4: 0, + tcp_fingerprint_blocks_ipv6: 0, + drop_rate_percentage: 0.0, + dropped_ip_addresses: 
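+            // The no-op build never observes packets, so the address tables
+            // stay empty and the unique-IP count stays zero.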
DroppedIpAddresses { + ipv4_addresses: HashMap::new(), + ipv6_addresses: HashMap::new(), + total_unique_dropped_ips: 0, + }, + } + } + + /// Create a summary string for logging + pub fn summary(&self) -> String { + format!( + "BPF Stats: {} packets processed, {} dropped ({:.2}%), {} unique IPs dropped", + self.total_packets_processed, + self.total_packets_dropped, + self.drop_rate_percentage, + self.dropped_ip_addresses.total_unique_dropped_ips + ) + } +} + +impl DroppedIpEvent { + pub fn new( + ip_address: String, + ip_version: IpVersion, + drop_count: u64, + drop_reason: DropReason, + ) -> Self { + Self { + event_type: "dropped_ips".to_string(), + timestamp: Utc::now(), + ip_address, + ip_version, + drop_count, + drop_reason, + } + } + + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + pub fn summary(&self) -> String { + format!( + "IP Drop Event: {} {:?} dropped {} times (reason: {:?})", + self.ip_address, + self.ip_version, + self.drop_count, + self.drop_reason + ) + } +} + +impl DroppedIpEvents { + pub fn new() -> Self { + Self { + timestamp: Utc::now(), + events: Vec::new(), + total_events: 0, + unique_ips: 0, + } + } + + pub fn add_event(&mut self, event: DroppedIpEvent) { + self.events.push(event); + self.total_events += 1; + } + + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + pub fn summary(&self) -> String { + format!( + "Dropped IP Events: {} events from {} unique IPs", + self.total_events, + self.unique_ips + ) + } + + pub fn get_top_dropped_ips(&self, limit: usize) -> Vec { + let mut events = self.events.clone(); + events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count)); + events.into_iter().take(limit).collect() + } +} + +/// Statistics collector for BPF access rules +#[derive(Clone)] +pub struct BpfStatsCollector { + _skels: Vec>>, + enabled: bool, +} + +impl BpfStatsCollector { + pub fn new(skels: Vec>>, enabled: bool) -> Self { + Self { _skels: skels, enabled } + } + + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + pub fn is_enabled(&self) -> bool { + self.enabled + } + + pub fn collect_stats(&self) -> Result, Box> { + Ok(Vec::new()) + } + + pub fn collect_aggregated_stats(&self) -> Result> { + Ok(BpfAccessStats::empty()) + } + + pub fn log_stats(&self) -> Result<(), Box> { + Ok(()) + } + + pub fn collect_dropped_ip_events(&self) -> Result> { + Ok(DroppedIpEvents::new()) + } + + pub fn log_dropped_ip_events(&self) -> Result<(), Box> { + Ok(()) + } + + pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { + Ok(()) + } +} + +/// Configuration for BPF statistics collection +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BpfStatsConfig { + pub enabled: bool, + pub log_interval_secs: u64, +} + +impl Default for BpfStatsConfig { + fn default() -> Self { + Self { + enabled: true, + log_interval_secs: 60, + } + } +} + +impl BpfStatsConfig { + pub fn new(enabled: bool, log_interval_secs: u64) -> Self { + Self { + enabled, + log_interval_secs, + } + } +} diff --git a/src/bpf_stub.rs b/src/bpf_stub.rs index 6957e7d..baf24e4 100644 --- a/src/bpf_stub.rs +++ b/src/bpf_stub.rs @@ -1,29 +1,29 @@ -use std::marker::PhantomData; - -#[derive(Clone, Debug)] -pub struct FilterSkel<'a> { - _marker: PhantomData<&'a ()>, -} - -impl<'a> FilterSkel<'a> { - pub fn new() -> Self { - Self { _marker: PhantomData } - } -} - -#[derive(Clone, Debug, Default)] -pub struct FilterSkelBuilder; - -impl FilterSkelBuilder { - pub fn open(&self) -> Result> { - Err("BPF support disabled at build time".into()) - } 
-} - -pub struct FilterSkelOpen; - -impl FilterSkelOpen { - pub fn load(self) -> Result, Box> { - Err("BPF support disabled at build time".into()) - } -} +use std::marker::PhantomData; + +#[derive(Clone, Debug)] +pub struct FilterSkel<'a> { + _marker: PhantomData<&'a ()>, +} + +impl<'a> FilterSkel<'a> { + pub fn new() -> Self { + Self { _marker: PhantomData } + } +} + +#[derive(Clone, Debug, Default)] +pub struct FilterSkelBuilder; + +impl FilterSkelBuilder { + pub fn open(&self) -> Result> { + Err("BPF support disabled at build time".into()) + } +} + +pub struct FilterSkelOpen; + +impl FilterSkelOpen { + pub fn load(self) -> Result, Box> { + Err("BPF support disabled at build time".into()) + } +} diff --git a/src/captcha_server.rs b/src/captcha_server.rs index f376955..b4bd9d3 100644 --- a/src/captcha_server.rs +++ b/src/captcha_server.rs @@ -1,153 +1,153 @@ -use axum::{ - body::{Body, Bytes}, - extract::{ConnectInfo, State}, - http::StatusCode, - response::{IntoResponse, Response}, - routing::{get, post}, - Router, -}; -use log::{error, info, warn}; -use std::collections::HashMap; -use std::net::SocketAddr; -use tokio::net::TcpListener; - -/// Start the captcha verification server on port 9181 -pub async fn start_captcha_server() -> anyhow::Result<()> { - let app = Router::new() - .route("/cgi-bin/captcha/verify", post(handle_captcha_verification)) - .route("/health", get(health_check)); - - let addr = "127.0.0.1:9181"; - let listener = TcpListener::bind(addr).await?; - info!("Starting captcha verification server on: {}", addr); - - axum::serve( - listener, - app.into_make_service_with_connect_info::(), - ) - .await?; - - Ok(()) -} - -/// Health check endpoint -async fn health_check() -> &'static str { - "OK" -} - -/// Handle captcha verification requests -async fn handle_captcha_verification( - ConnectInfo(peer_addr): ConnectInfo, - State(()): State<()>, - body: Bytes, -) -> Response { - use crate::waf::actions::captcha::{validate_and_mark_captcha, apply_captcha_challenge}; - - info!("Starting captcha verification handler from: {} with body size: {}", peer_addr, body.len()); - - // Parse form data from request body - let form_data: HashMap = match String::from_utf8(body.to_vec()) { - Ok(body_str) => { - info!("Captcha verification request body: {}", body_str); - url::form_urlencoded::parse(body_str.as_bytes()) - .into_owned() - .collect() - } - Err(e) => { - error!("Failed to parse captcha verification request body as UTF-8: {}", e); - return (StatusCode::BAD_REQUEST, "Invalid request body").into_response(); - } - }; - - info!("Parsed form data: {:?}", form_data); - - // Extract captcha response and JWT token from form data - let captcha_response = match form_data.get("captcha_response") { - Some(response) => response.clone(), - None => { - warn!("Missing captcha_response in verification request from {}", peer_addr.ip()); - return (StatusCode::BAD_REQUEST, "Missing captcha_response").into_response(); - } - }; - - let jwt_token = match form_data.get("jwt_token") { - Some(token) => token.clone(), - None => { - warn!("Missing jwt_token in verification request from {}", peer_addr.ip()); - return (StatusCode::BAD_REQUEST, "Missing jwt_token").into_response(); - } - }; - - // Get user agent from request (would need to be passed in a real implementation) - let user_agent = String::from("Mozilla/5.0"); // Placeholder - - // Validate captcha and mark token as validated - match validate_and_mark_captcha( - captcha_response, - jwt_token.clone(), - peer_addr.ip().to_string(), - Some(user_agent), - 
) - .await - { - Ok(true) => { - info!("Captcha verification successful for IP: {}", peer_addr.ip()); - - // Return 302 redirect with Set-Cookie header - Response::builder() - .status(StatusCode::FOUND) - .header("Location", "/") - .header( - "Set-Cookie", - format!( - "captcha_token={}; Path=/; Max-Age=3600; HttpOnly; SameSite=Lax", - jwt_token - ), - ) - .header("Cache-Control", "no-cache, no-store, must-revalidate") - .header("Pragma", "no-cache") - .header("Expires", "0") - .body(Body::empty()) - .unwrap() - } - Ok(false) => { - warn!("Captcha verification failed for IP: {}", peer_addr.ip()); - - // Generate failure page with retry option - let failure_html = apply_captcha_challenge().unwrap_or_else(|_| { - r#" - - - Verification Failed - - - -

-<body>
-  <h1>Verification Failed</h1>
-  <p>Captcha verification failed. Please try again.</p>
-  <a href="/">Return to main page</a>
-</body>
-</html>
- -"# - .to_string() - }); - - Response::builder() - .status(StatusCode::FORBIDDEN) - .header("Content-Type", "text/html; charset=utf-8") - .header("Cache-Control", "no-cache, no-store, must-revalidate") - .header("Pragma", "no-cache") - .header("Expires", "0") - .body(Body::from(failure_html)) - .unwrap() - } - Err(e) => { - error!("Captcha verification error for IP {}: {}", peer_addr.ip(), e); - (StatusCode::INTERNAL_SERVER_ERROR, "Verification error").into_response() - } - } -} - +use axum::{ + body::{Body, Bytes}, + extract::{ConnectInfo, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::{get, post}, + Router, +}; +use log::{error, info, warn}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +/// Start the captcha verification server on port 9181 +pub async fn start_captcha_server() -> anyhow::Result<()> { + let app = Router::new() + .route("/cgi-bin/captcha/verify", post(handle_captcha_verification)) + .route("/health", get(health_check)); + + let addr = "127.0.0.1:9181"; + let listener = TcpListener::bind(addr).await?; + info!("Starting captcha verification server on: {}", addr); + + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await?; + + Ok(()) +} + +/// Health check endpoint +async fn health_check() -> &'static str { + "OK" +} + +/// Handle captcha verification requests +async fn handle_captcha_verification( + ConnectInfo(peer_addr): ConnectInfo, + State(()): State<()>, + body: Bytes, +) -> Response { + use crate::waf::actions::captcha::{validate_and_mark_captcha, apply_captcha_challenge}; + + info!("Starting captcha verification handler from: {} with body size: {}", peer_addr, body.len()); + + // Parse form data from request body + let form_data: HashMap = match String::from_utf8(body.to_vec()) { + Ok(body_str) => { + info!("Captcha verification request body: {}", body_str); + url::form_urlencoded::parse(body_str.as_bytes()) + .into_owned() + .collect() + } + Err(e) => { + error!("Failed to parse captcha verification request body as UTF-8: {}", e); + return (StatusCode::BAD_REQUEST, "Invalid request body").into_response(); + } + }; + + info!("Parsed form data: {:?}", form_data); + + // Extract captcha response and JWT token from form data + let captcha_response = match form_data.get("captcha_response") { + Some(response) => response.clone(), + None => { + warn!("Missing captcha_response in verification request from {}", peer_addr.ip()); + return (StatusCode::BAD_REQUEST, "Missing captcha_response").into_response(); + } + }; + + let jwt_token = match form_data.get("jwt_token") { + Some(token) => token.clone(), + None => { + warn!("Missing jwt_token in verification request from {}", peer_addr.ip()); + return (StatusCode::BAD_REQUEST, "Missing jwt_token").into_response(); + } + }; + + // Get user agent from request (would need to be passed in a real implementation) + let user_agent = String::from("Mozilla/5.0"); // Placeholder + + // Validate captcha and mark token as validated + match validate_and_mark_captcha( + captcha_response, + jwt_token.clone(), + peer_addr.ip().to_string(), + Some(user_agent), + ) + .await + { + Ok(true) => { + info!("Captcha verification successful for IP: {}", peer_addr.ip()); + + // Return 302 redirect with Set-Cookie header + Response::builder() + .status(StatusCode::FOUND) + .header("Location", "/") + .header( + "Set-Cookie", + format!( + "captcha_token={}; Path=/; Max-Age=3600; HttpOnly; SameSite=Lax", + jwt_token + ), + ) + .header("Cache-Control", 
"no-cache, no-store, must-revalidate") + .header("Pragma", "no-cache") + .header("Expires", "0") + .body(Body::empty()) + .unwrap() + } + Ok(false) => { + warn!("Captcha verification failed for IP: {}", peer_addr.ip()); + + // Generate failure page with retry option + let failure_html = apply_captcha_challenge().unwrap_or_else(|_| { + r#" + + + Verification Failed + + + +

+<body>
+  <h1>Verification Failed</h1>
+  <p>Captcha verification failed. Please try again.</p>
+  <a href="/">Return to main page</a>
+</body>
+</html>
+ +"# + .to_string() + }); + + Response::builder() + .status(StatusCode::FORBIDDEN) + .header("Content-Type", "text/html; charset=utf-8") + .header("Cache-Control", "no-cache, no-store, must-revalidate") + .header("Pragma", "no-cache") + .header("Expires", "0") + .body(Body::from(failure_html)) + .unwrap() + } + Err(e) => { + error!("Captcha verification error for IP {}: {}", peer_addr.ip(), e); + (StatusCode::INTERNAL_SERVER_ERROR, "Verification error").into_response() + } + } +} + diff --git a/src/cli.rs b/src/cli.rs index f8ea421..7481757 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,1233 +1,1233 @@ -use std::{path::PathBuf, env}; - -use anyhow::Result; -use clap::Parser; -use clap::ValueEnum; -use serde::{Deserialize, Serialize}; - -use crate::waf::actions::captcha::CaptchaProvider; - -/// TLS operating mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] -#[serde(rename_all = "lowercase")] -pub enum TlsMode { - /// TLS is disabled - Disabled, -} - -/// Application operating mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] -#[serde(rename_all = "lowercase")] -pub enum AppMode { - /// Agent mode: Only access rules and monitoring (no proxy) - Agent, - /// Proxy mode: Full reverse proxy functionality - Proxy, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Config { - #[serde(default = "default_mode")] - pub mode: String, - - /// Number of Tokio worker threads. In agent mode, defaults to 0 (single-threaded). - /// Set to a positive number to force multi-threaded runtime. - /// In proxy mode, defaults to number of CPUs. - #[serde(default)] - pub worker_threads: Option, - - // Global server options (moved from server section) - #[serde(default)] - pub redis: RedisConfig, - #[serde(default)] - pub network: NetworkConfig, - #[serde(default, alias = "arxignis")] - pub platform: Gen0SecConfig, - #[serde(default)] - pub geoip: GeoipConfig, - #[serde(default)] - pub content_scanning: ContentScanningCliConfig, - #[serde(default)] - pub logging: LoggingConfig, - #[serde(default)] - pub bpf_stats: BpfStatsConfig, - #[serde(default)] - pub tcp_fingerprint: TcpFingerprintConfig, - #[serde(default)] - pub daemon: DaemonConfig, - #[serde(default)] - pub pingora: PingoraConfig, - #[serde(default)] - pub acme: AcmeConfig, -} - -fn default_mode() -> String { "agent".to_string() } - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ProxyProtocolConfig { - #[serde(default = "default_proxy_protocol_enabled")] - pub enabled: bool, - #[serde(default = "default_proxy_protocol_timeout")] - pub timeout_ms: u64, -} - -fn default_proxy_protocol_enabled() -> bool { false } -fn default_proxy_protocol_timeout() -> u64 { 1000 } - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct HealthCheckConfig { - #[serde(default = "default_health_check_enabled")] - pub enabled: bool, - #[serde(default = "default_health_check_endpoint")] - pub endpoint: String, - #[serde(default = "default_health_check_port")] - pub port: String, - #[serde(default = "default_health_check_methods")] - pub methods: Vec, - #[serde(default = "default_health_check_allowed_cidrs")] - pub allowed_cidrs: Vec, -} - -fn default_health_check_enabled() -> bool { true } -fn default_health_check_endpoint() -> String { "/health".to_string() } -fn default_health_check_port() -> String { "0.0.0.0:8080".to_string() } -fn default_health_check_methods() -> Vec { vec!["GET".to_string(), "HEAD".to_string()] } -fn 
default_health_check_allowed_cidrs() -> Vec { vec![] } - - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RedisConfig { - #[serde(default)] - pub url: String, - #[serde(default)] - pub prefix: String, - /// Redis SSL/TLS configuration - #[serde(default)] - pub ssl: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RedisSslConfig { - /// Path to CA certificate file (PEM format) - pub ca_cert_path: Option, - /// Path to client certificate file (PEM format, optional) - pub client_cert_path: Option, - /// Path to client private key file (PEM format, optional) - pub client_key_path: Option, - /// Skip certificate verification (for testing with self-signed certs) - #[serde(default)] - pub insecure: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct NetworkConfig { - #[serde(default)] - pub iface: String, - #[serde(default)] - pub ifaces: Vec, - #[serde(default)] - pub disable_xdp: bool, - /// IP version support mode: "ipv4", "ipv6", or "both" (default: "both") - /// Note: XDP requires IPv6 to be enabled at kernel level for attachment, - /// even in IPv4-only mode. Set to "ipv4" to skip XDP if IPv6 cannot be enabled. - #[serde(default = "default_ip_version")] - pub ip_version: String, - /// Firewall backend mode: auto, xdp, nftables, iptables, none - #[serde(default)] - pub firewall_mode: crate::firewall::FirewallMode, -} - -fn default_ip_version() -> String { - "both".to_string() -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct Gen0SecConfig { - #[serde(default)] - pub api_key: String, - #[serde(default)] - pub workspace_id: String, - #[serde(default = "default_base_url")] - pub base_url: String, - /// Threat MMDB database configuration - #[serde(default)] - pub threat: GeoipDatabaseConfig, - #[serde(default = "default_log_sending_enabled")] - pub log_sending_enabled: bool, - #[serde(default = "default_include_response_body")] - pub include_response_body: bool, - #[serde(default = "default_max_body_size")] - pub max_body_size: usize, - #[serde(default)] - pub captcha: CaptchaConfig, -} - -fn default_base_url() -> String { - "https://api.gen0sec.com/v1".to_string() -} - - - -fn default_log_sending_enabled() -> bool { - true -} - -fn default_include_response_body() -> bool { - true -} - -fn default_max_body_size() -> usize { - 1024 * 1024 // 1MB -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct GeoipDatabaseConfig { - /// URL to download GeoIP MMDB file - #[serde(default)] - pub url: String, - /// Optional local path to store/read the GeoIP MMDB file - #[serde(default)] - pub path: Option, - /// Optional custom headers to add to download requests - #[serde(default)] - pub headers: Option>, - /// Refresh interval in seconds (optional, overrides parent refresh_secs if set) - #[serde(default)] - pub refresh_secs: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct GeoipConfig { - /// Country database configuration - #[serde(default)] - pub country: GeoipDatabaseConfig, - /// ASN database configuration - #[serde(default)] - pub asn: GeoipDatabaseConfig, - /// City database configuration - #[serde(default)] - pub city: GeoipDatabaseConfig, - /// How often to refresh the GeoIP MMDB from the remote URL (seconds). 
Default: 28800 (8 hours) - #[serde(default = "default_geoip_refresh_secs")] - pub refresh_secs: u64, -} - -fn default_geoip_refresh_secs() -> u64 { - 28800 // 8 hours -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ContentScanningCliConfig { - #[serde(default = "default_scanning_enabled")] - pub enabled: bool, - #[serde(default = "default_clamav_server")] - pub clamav_server: String, - #[serde(default = "default_max_file_size")] - pub max_file_size: usize, - #[serde(default)] - pub scan_content_types: Vec, - #[serde(default)] - pub skip_extensions: Vec, - #[serde(default = "default_scan_expression")] - pub scan_expression: String, -} - -fn default_scanning_enabled() -> bool { false } -fn default_clamav_server() -> String { "localhost:3310".to_string() } -fn default_max_file_size() -> usize { 10 * 1024 * 1024 } -fn default_scan_expression() -> String { "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() } - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct LoggingConfig { - #[serde(default)] - pub level: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct CaptchaConfig { - #[serde(default)] - pub site_key: Option, - #[serde(default)] - pub secret_key: Option, - #[serde(default)] - pub jwt_secret: Option, - #[serde(default)] - pub provider: String, - #[serde(default)] - pub token_ttl: u64, - #[serde(default)] - pub cache_ttl: u64, -} - -impl Config { - pub fn load_from_file(path: &PathBuf) -> Result { - let content = std::fs::read_to_string(path)?; - let config: Config = serde_yaml::from_str(&content)?; - Ok(config) - } - - pub fn default() -> Self { - Self { - mode: "agent".to_string(), - worker_threads: None, - redis: RedisConfig { - url: "redis://127.0.0.1/0".to_string(), - prefix: "ax:synapse".to_string(), - ssl: None, - }, - network: NetworkConfig { - iface: "eth0".to_string(), - ifaces: vec![], - disable_xdp: false, - ip_version: "both".to_string(), - firewall_mode: crate::firewall::FirewallMode::default(), - }, - platform: Gen0SecConfig { - api_key: "".to_string(), - workspace_id: "".to_string(), - base_url: "https://api.gen0sec.com/v1".to_string(), - threat: GeoipDatabaseConfig::default(), - log_sending_enabled: true, - include_response_body: true, - max_body_size: 1024 * 1024, // 1MB - captcha: CaptchaConfig { - site_key: None, - secret_key: None, - jwt_secret: None, - provider: "hcaptcha".to_string(), - token_ttl: 7200, - cache_ttl: 300, - }, - }, - geoip: GeoipConfig { - country: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - asn: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - city: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - refresh_secs: 28800, - }, - content_scanning: ContentScanningCliConfig { - enabled: false, - clamav_server: "localhost:3310".to_string(), - max_file_size: 10 * 1024 * 1024, - scan_content_types: vec![ - "text/html".to_string(), - "application/x-www-form-urlencoded".to_string(), - "multipart/form-data".to_string(), - "application/json".to_string(), - "text/plain".to_string(), - ], - skip_extensions: vec![], - scan_expression: default_scan_expression(), - }, - logging: LoggingConfig { - level: "info".to_string(), - }, - bpf_stats: BpfStatsConfig::default(), - tcp_fingerprint: TcpFingerprintConfig::default(), - daemon: DaemonConfig::default(), - pingora: PingoraConfig::default(), - 
acme: AcmeConfig::default(), - } - } - - pub fn merge_with_args(&mut self, args: &Args) { - // Override config values with command line arguments if provided - - if !args.ifaces.is_empty() { - self.network.ifaces = args.ifaces.clone(); - } - if args.disable_xdp { - self.network.disable_xdp = true; - } - if let Some(ref mode) = args.firewall_mode { - self.network.firewall_mode = match mode.to_lowercase().as_str() { - "auto" => crate::firewall::FirewallMode::Auto, - "xdp" => crate::firewall::FirewallMode::Xdp, - "nftables" => crate::firewall::FirewallMode::Nftables, - "iptables" => crate::firewall::FirewallMode::Iptables, - "none" => crate::firewall::FirewallMode::None, - _ => { - log::warn!("Unknown firewall mode '{}', using auto", mode); - crate::firewall::FirewallMode::Auto - } - }; - } - if let Some(api_key) = &args.arxignis_api_key { - self.platform.api_key = api_key.clone(); - } - if !args.arxignis_workspace_id.is_empty() { - self.platform.workspace_id = args.arxignis_workspace_id.clone(); - } - if !args.arxignis_base_url.is_empty() && args.arxignis_base_url != "https://api.gen0sec.com/v1" { - self.platform.base_url = args.arxignis_base_url.clone(); - } - if let Some(log_sending_enabled) = args.arxignis_log_sending_enabled { - self.platform.log_sending_enabled = log_sending_enabled; - } - self.platform.include_response_body = args.arxignis_include_response_body; - self.platform.max_body_size = args.arxignis_max_body_size; - if args.captcha_site_key.is_some() { - self.platform.captcha.site_key = args.captcha_site_key.clone(); - } - if args.captcha_secret_key.is_some() { - self.platform.captcha.secret_key = args.captcha_secret_key.clone(); - } - if args.captcha_jwt_secret.is_some() { - self.platform.captcha.jwt_secret = args.captcha_jwt_secret.clone(); - } - if let Some(provider) = &args.captcha_provider { - self.platform.captcha.provider = format!("{:?}", provider).to_lowercase(); - } - - // Proxy protocol configuration overrides - if args.proxy_protocol_enabled { - self.pingora.proxy_protocol.enabled = true; - } - if args.proxy_protocol_timeout != 1000 { - self.pingora.proxy_protocol.timeout_ms = args.proxy_protocol_timeout; - } - - // Daemon configuration overrides - if args.daemon { - self.daemon.enabled = true; - } - if args.daemon_pid_file != "/var/run/synapse.pid" { - self.daemon.pid_file = args.daemon_pid_file.clone(); - } - if args.daemon_working_dir != "/" { - self.daemon.working_directory = args.daemon_working_dir.clone(); - } - if args.daemon_stdout != "/var/log/synapse.out" { - self.daemon.stdout = args.daemon_stdout.clone(); - } - if args.daemon_stderr != "/var/log/synapse.err" { - self.daemon.stderr = args.daemon_stderr.clone(); - } - if args.daemon_user.is_some() { - self.daemon.user = args.daemon_user.clone(); - } - if args.daemon_group.is_some() { - self.daemon.group = args.daemon_group.clone(); - } - - // Redis configuration overrides - if !args.redis_url.is_empty() && args.redis_url != "redis://127.0.0.1/0" { - self.redis.url = args.redis_url.clone(); - } - if !args.redis_prefix.is_empty() && args.redis_prefix != "ax:synapse" { - self.redis.prefix = args.redis_prefix.clone(); - } - } - - pub fn validate_required_fields(&mut self, args: &Args) -> Result<()> { - // Check if arxignis API key is provided - only warn if not provided - // (to support old config format that doesn't have this field) - if args.arxignis_api_key.is_none() && self.platform.api_key.is_empty() { - log::warn!("Gen0Sec API key not provided. 
Some features may not work."); - } - - Ok(()) - } - - pub fn load_from_args(args: &Args) -> Result { - let mut config = if let Some(config_path) = &args.config { - Self::load_from_file(config_path)? - } else { - Self::default() - }; - - config.merge_with_args(args); - config.apply_env_overrides(); - config.validate_required_fields(args)?; - Ok(config) - } - - pub fn apply_env_overrides(&mut self) { - // Helper function to get env var, preferring non-prefixed version - // AX_ prefix is deprecated and will be removed in a future version - fn get_env(name: &str) -> Option { - // Prefer non-prefixed version - if let Ok(val) = env::var(name) { - return Some(val); - } - // Fallback to AX_ prefix with deprecation warning - let ax_name = format!("AX_{}", name); - if let Ok(val) = env::var(&ax_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); - return Some(val); - } - None - } - - // Helper for arxignis vars: name > AX_name > AX_ARXIGNIS_name > ARXIGNIS_name (backward compat) - // ARXIGNIS prefix is deprecated and will be removed in a future version - fn get_env_arxignis(name: &str) -> Option { - // Prefer non-prefixed version - if let Ok(val) = env::var(name) { - return Some(val); - } - // Fallback to AX_ prefix with deprecation warning - let ax_name = format!("AX_{}", name); - if let Ok(val) = env::var(&ax_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); - return Some(val); - } - // Fallback to AX_ARXIGNIS_ prefix with deprecation warning - let ax_arxignis_name = format!("AX_ARXIGNIS_{}", name); - if let Ok(val) = env::var(&ax_arxignis_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", ax_arxignis_name, name); - return Some(val); - } - // Fallback to ARXIGNIS_ prefix with deprecation warning - let arxignis_name = format!("ARXIGNIS_{}", name); - if let Ok(val) = env::var(&arxignis_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. 
The ARXIGNIS prefix will be removed in a future version.", arxignis_name, name); - return Some(val); - } - None - } - - // Mode override - if let Some(val) = get_env("MODE") { - self.mode = val; - } - - // Redis configuration overrides - if let Some(val) = get_env("REDIS_URL") { - self.redis.url = val; - } - if let Some(val) = get_env("REDIS_PREFIX") { - self.redis.prefix = val; - } - - // Redis SSL configuration overrides - // Read all SSL environment variables once - let ca_cert_path = get_env("REDIS_SSL_CA_CERT_PATH"); - let client_cert_path = get_env("REDIS_SSL_CLIENT_CERT_PATH"); - let client_key_path = get_env("REDIS_SSL_CLIENT_KEY_PATH"); - let insecure_val = get_env("REDIS_SSL_INSECURE"); - - // If any SSL env var is set, ensure SSL config exists - if ca_cert_path.is_some() - || client_cert_path.is_some() - || client_key_path.is_some() - || insecure_val.is_some() { - - // Create SSL config if it doesn't exist - if self.redis.ssl.is_none() { - // Parse insecure value if provided, default to false - let insecure_default = insecure_val - .as_ref() - .and_then(|v| v.parse::().ok()) - .unwrap_or(false); - - self.redis.ssl = Some(RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: insecure_default, - }); - } - - // Update the SSL config with values from environment variables - let ssl = self.redis.ssl.as_mut().expect("SSL config should exist here"); - if let Some(val) = ca_cert_path { - ssl.ca_cert_path = Some(val); - } - if let Some(val) = client_cert_path { - ssl.client_cert_path = Some(val); - } - if let Some(val) = client_key_path { - ssl.client_key_path = Some(val); - } - if let Some(val) = insecure_val { - if let Ok(insecure) = val.parse::() { - ssl.insecure = insecure; - } - } - } - - // Network configuration overrides - if let Some(val) = get_env("NETWORK_IFACE") { - self.network.iface = val; - } - if let Some(val) = get_env("NETWORK_IFACES") { - self.network.ifaces = val.split(',').map(|s| s.trim().to_string()).collect(); - } - if let Some(val) = get_env("NETWORK_DISABLE_XDP") { - self.network.disable_xdp = val.parse().unwrap_or(false); - } - if let Some(val) = get_env("FIREWALL_MODE").or_else(|| get_env("NETWORK_FIREWALL_MODE")) { - self.network.firewall_mode = match val.to_lowercase().as_str() { - "auto" => crate::firewall::FirewallMode::Auto, - "xdp" => crate::firewall::FirewallMode::Xdp, - "nftables" => crate::firewall::FirewallMode::Nftables, - "iptables" => crate::firewall::FirewallMode::Iptables, - "none" => crate::firewall::FirewallMode::None, - _ => crate::firewall::FirewallMode::Auto, - }; - } - if let Some(val) = get_env("NETWORK_IP_VERSION") { - // Validate ip_version value - match val.as_str() { - "ipv4" | "ipv6" | "both" => { - self.network.ip_version = val; - } - _ => { - log::warn!("Invalid NETWORK_IP_VERSION value '{}', using default 'both'. 
Valid values: ipv4, ipv6, both", val); - } - } - } - - // Gen0Sec configuration overrides - // Supports: AX_API_KEY, API_KEY, AX_ARXIGNIS_API_KEY, ARXIGNIS_API_KEY (backward compat) - if let Some(val) = get_env_arxignis("API_KEY") { - self.platform.api_key = val; - } - if let Some(val) = get_env_arxignis("WORKSPACE_ID") { - self.platform.workspace_id = val; - } - if let Some(val) = get_env_arxignis("BASE_URL") { - self.platform.base_url = val; - } - if let Some(val) = get_env_arxignis("LOG_SENDING_ENABLED") { - if let Ok(parsed) = val.parse::() { - self.platform.log_sending_enabled = parsed; - } - } - if let Some(val) = get_env_arxignis("INCLUDE_RESPONSE_BODY") { - self.platform.include_response_body = val.parse().unwrap_or(true); - } - if let Some(val) = get_env_arxignis("MAX_BODY_SIZE") { - self.platform.max_body_size = val.parse().unwrap_or(1024 * 1024); - } - - // Logging configuration overrides - if let Some(val) = get_env("LOGGING_LEVEL") { - self.logging.level = val; - } - - // Content scanning overrides - if let Some(val) = get_env("CONTENT_SCANNING_ENABLED") { - self.content_scanning.enabled = val.parse().unwrap_or(false); - } - if let Some(val) = get_env("CLAMAV_SERVER") { - self.content_scanning.clamav_server = val; - } - if let Some(val) = get_env("CONTENT_MAX_FILE_SIZE") { - self.content_scanning.max_file_size = val.parse().unwrap_or(10 * 1024 * 1024); - } - if let Some(val) = get_env("CONTENT_SCAN_CONTENT_TYPES") { - self.content_scanning.scan_content_types = val.split(',').map(|s| s.trim().to_string()).collect(); - } - if let Some(val) = get_env("CONTENT_SKIP_EXTENSIONS") { - self.content_scanning.skip_extensions = val.split(',').map(|s| s.trim().to_string()).collect(); - } - if let Some(val) = get_env("CONTENT_SCAN_EXPRESSION") { - self.content_scanning.scan_expression = val; - } - - // Captcha configuration overrides - if let Some(val) = get_env("CAPTCHA_SITE_KEY") { - self.platform.captcha.site_key = Some(val); - } - if let Some(val) = get_env("CAPTCHA_SECRET_KEY") { - self.platform.captcha.secret_key = Some(val); - } - if let Some(val) = get_env("CAPTCHA_JWT_SECRET") { - self.platform.captcha.jwt_secret = Some(val); - } - if let Some(val) = get_env("CAPTCHA_PROVIDER") { - self.platform.captcha.provider = val; - } - if let Some(val) = get_env("CAPTCHA_TOKEN_TTL") { - self.platform.captcha.token_ttl = val.parse().unwrap_or(7200); - } - if let Some(val) = get_env("CAPTCHA_CACHE_TTL") { - self.platform.captcha.cache_ttl = val.parse().unwrap_or(300); - } - - // Proxy protocol configuration overrides - if let Some(val) = get_env("PROXY_PROTOCOL_ENABLED") { - self.pingora.proxy_protocol.enabled = val.parse().unwrap_or(false); - } - if let Some(val) = get_env("PROXY_PROTOCOL_TIMEOUT") { - self.pingora.proxy_protocol.timeout_ms = val.parse().unwrap_or(1000); - } - - // Daemon configuration overrides - if let Some(val) = get_env("DAEMON_ENABLED") { - self.daemon.enabled = val.parse().unwrap_or(false); - } - if let Some(val) = get_env("DAEMON_PID_FILE") { - self.daemon.pid_file = val; - } - if let Some(val) = get_env("DAEMON_WORKING_DIRECTORY") { - self.daemon.working_directory = val; - } - if let Some(val) = get_env("DAEMON_STDOUT") { - self.daemon.stdout = val; - } - if let Some(val) = get_env("DAEMON_STDERR") { - self.daemon.stderr = val; - } - if let Some(val) = get_env("DAEMON_USER") { - self.daemon.user = Some(val); - } - if let Some(val) = get_env("DAEMON_GROUP") { - self.daemon.group = Some(val); - } - if let Some(val) = get_env("DAEMON_CHOWN_PID_FILE") { - 
self.daemon.chown_pid_file = val.parse().unwrap_or(true); - } - } -} - -#[derive(Parser, Debug, Clone)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// Path to configuration file (YAML format) - #[arg(long, short = 'c')] - pub config: Option, - - /// Path to security rules configuration file (access rules and WAF rules) - /// Used when no Gen0Sec API key is provided as a fallback - #[arg(long, default_value = "security_rules.yaml")] - pub security_rules_config: PathBuf, - - /// Clear a specific certificate from local filesystem and Redis - #[arg(long)] - pub clear_certificate: Option, - - /// Redis connection URL for ACME cache storage. - #[arg(long, default_value = "redis://127.0.0.1/0")] - pub redis_url: String, - - /// Namespace prefix for Redis ACME cache entries. - #[arg(long, default_value = "ax:synapse")] - pub redis_prefix: String, - - /// The network interface to attach the XDP program to. - #[arg(short, long, default_value = "eth0")] - pub iface: String, - - /// Additional network interfaces for XDP attach (comma-separated). If set, overrides --iface. - #[arg(long, value_delimiter = ',', num_args = 0..)] - pub ifaces: Vec, - - #[arg(long)] - pub arxignis_api_key: Option, - - /// Base URL for Gen0Sec API. - #[arg(long, default_value = "https://api.gen0sec.com/v1")] - pub arxignis_base_url: String, - - /// Workspace ID for agent identity - #[arg(long, default_value = "")] - pub arxignis_workspace_id: String, - - /// Enable sending access logs to arxignis server - #[arg(long)] - pub arxignis_log_sending_enabled: Option, - - /// Include response body in access logs - #[arg(long, default_value_t = true)] - pub arxignis_include_response_body: bool, - - /// Maximum size for request/response bodies in access logs (bytes) - #[arg(long, default_value = "1048576")] - pub arxignis_max_body_size: usize, - - /// Log level (error, warn, info, debug, trace) - #[arg(long, value_enum, default_value_t = LogLevel::Info)] - pub log_level: LogLevel, - - /// Disable XDP packet filtering (run without BPF/XDP) - #[arg(long, default_value_t = false)] - pub disable_xdp: bool, - - /// Firewall backend mode (auto, xdp, nftables, iptables, none) - #[arg(long)] - pub firewall_mode: Option, - - /// Captcha site key for security verification - #[arg(long)] - pub captcha_site_key: Option, - - /// Captcha secret key for security verification - #[arg(long)] - pub captcha_secret_key: Option, - - /// JWT secret key for captcha token signing - #[arg(long)] - pub captcha_jwt_secret: Option, - - /// Captcha provider (hcaptcha, recaptcha, turnstile) - #[arg(long, value_enum)] - pub captcha_provider: Option, - - - /// Captcha token TTL in seconds - #[arg(long, default_value = "7200")] - pub captcha_token_ttl: u64, - - /// Captcha validation cache TTL in seconds - #[arg(long, default_value = "300")] - pub captcha_cache_ttl: u64, - - /// Enable PROXY protocol support for TCP connections - #[arg(long, default_value_t = false)] - pub proxy_protocol_enabled: bool, - - /// PROXY protocol timeout in milliseconds - #[arg(long, default_value = "1000")] - pub proxy_protocol_timeout: u64, - - /// Run as daemon in background - #[arg(long, short = 'd', default_value_t = false)] - pub daemon: bool, - - /// PID file path for daemon mode - #[arg(long, default_value = "/var/run/synapse.pid")] - pub daemon_pid_file: String, - - /// Working directory for daemon mode - #[arg(long, default_value = "/")] - pub daemon_working_dir: String, - - /// Stdout log file for daemon mode - #[arg(long, default_value = 
"/var/log/synapse.out")] - pub daemon_stdout: String, - - /// Stderr log file for daemon mode - #[arg(long, default_value = "/var/log/synapse.err")] - pub daemon_stderr: String, - - /// User to run daemon as - #[arg(long)] - pub daemon_user: Option, - - /// Group to run daemon as - #[arg(long)] - pub daemon_group: Option, -} - -#[derive(Copy, Clone, Debug, ValueEnum)] -pub enum LogLevel { - Error, - Warn, - Info, - Debug, - Trace, -} - -impl LogLevel { - pub fn to_level_filter(self) -> log::LevelFilter { - match self { - LogLevel::Error => log::LevelFilter::Error, - LogLevel::Warn => log::LevelFilter::Warn, - LogLevel::Info => log::LevelFilter::Info, - LogLevel::Debug => log::LevelFilter::Debug, - LogLevel::Trace => log::LevelFilter::Trace, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct BpfStatsConfig { - #[serde(default = "default_bpf_stats_enabled")] - pub enabled: bool, - #[serde(default = "default_bpf_stats_log_interval")] - pub log_interval_secs: u64, - #[serde(default = "default_bpf_stats_enable_dropped_ip_events")] - pub enable_dropped_ip_events: bool, - #[serde(default = "default_bpf_stats_dropped_ip_events_interval")] - pub dropped_ip_events_interval_secs: u64, -} - -fn default_bpf_stats_enabled() -> bool { true } -fn default_bpf_stats_log_interval() -> u64 { 60 } -fn default_bpf_stats_enable_dropped_ip_events() -> bool { true } -fn default_bpf_stats_dropped_ip_events_interval() -> u64 { 30 } - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct TcpFingerprintConfig { - #[serde(default = "default_tcp_fingerprint_enabled")] - pub enabled: bool, - #[serde(default = "default_tcp_fingerprint_log_interval")] - pub log_interval_secs: u64, - #[serde(default = "default_tcp_fingerprint_enable_fingerprint_events")] - pub enable_fingerprint_events: bool, - #[serde(default = "default_tcp_fingerprint_events_interval")] - pub fingerprint_events_interval_secs: u64, - #[serde(default = "default_tcp_fingerprint_min_packet_count")] - pub min_packet_count: u32, - #[serde(default = "default_tcp_fingerprint_min_connection_duration")] - pub min_connection_duration_secs: u64, -} - -fn default_tcp_fingerprint_enabled() -> bool { true } -fn default_tcp_fingerprint_log_interval() -> u64 { 60 } -fn default_tcp_fingerprint_enable_fingerprint_events() -> bool { true } -fn default_tcp_fingerprint_events_interval() -> u64 { 30 } -fn default_tcp_fingerprint_min_packet_count() -> u32 { 3 } -fn default_tcp_fingerprint_min_connection_duration() -> u64 { 1 } - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct DaemonConfig { - #[serde(default = "default_daemon_enabled")] - pub enabled: bool, - #[serde(default = "default_daemon_pid_file")] - pub pid_file: String, - #[serde(default = "default_daemon_working_directory")] - pub working_directory: String, - #[serde(default = "default_daemon_stdout")] - pub stdout: String, - #[serde(default = "default_daemon_stderr")] - pub stderr: String, - pub user: Option, - pub group: Option, - #[serde(default = "default_daemon_chown_pid_file")] - pub chown_pid_file: bool, -} - -fn default_daemon_enabled() -> bool { false } -fn default_daemon_pid_file() -> String { "/var/run/synapse.pid".to_string() } -fn default_daemon_working_directory() -> String { "/".to_string() } -fn default_daemon_stdout() -> String { "/var/log/synapse.out".to_string() } -fn default_daemon_stderr() -> String { "/var/log/synapse.err".to_string() } -fn default_daemon_chown_pid_file() -> bool { true } - -#[derive(Debug, Clone, Serialize, 
Deserialize, Default)] -pub struct PingoraConfig { - #[serde(default)] - pub proxy_address_http: String, - #[serde(default)] - pub proxy_address_tls: Option, - #[serde(default)] - pub proxy_certificates: Option, - #[serde(default = "default_pingora_tls_grade")] - pub proxy_tls_grade: String, - #[serde(default)] - pub default_certificate: Option, - #[serde(default)] - pub upstreams_conf: String, - #[serde(default)] - pub config_address: String, - #[serde(default = "default_pingora_config_api_enabled")] - pub config_api_enabled: bool, - #[serde(default)] - pub master_key: String, - #[serde(default = "default_pingora_log_level")] - pub log_level: String, - #[serde(default = "default_pingora_healthcheck_method")] - pub healthcheck_method: String, - #[serde(default = "default_pingora_healthcheck_interval")] - pub healthcheck_interval: u16, - #[serde(default)] - pub proxy_protocol: ProxyProtocolConfig, -} - -fn default_pingora_tls_grade() -> String { "medium".to_string() } -fn default_pingora_config_api_enabled() -> bool { true } -fn default_pingora_log_level() -> String { "debug".to_string() } -fn default_pingora_healthcheck_method() -> String { "HEAD".to_string() } -fn default_pingora_healthcheck_interval() -> u16 { 2 } - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AcmeConfig { - /// Enable embedded ACME server - #[serde(default = "default_acme_enabled")] - pub enabled: bool, - /// Port for ACME server (e.g., 9180) - #[serde(default = "default_acme_port")] - pub port: u16, - /// Email for ACME account - pub email: Option, - /// Storage path for certificates - #[serde(default = "default_acme_storage_path")] - pub storage_path: String, - /// Storage type: "file" or "redis" (defaults to "file", or "redis" if redis_url is set) - pub storage_type: Option, - /// Use development/staging ACME server - #[serde(default)] - pub development: bool, - /// Redis URL for storage (optional, uses global redis.url if not set) - pub redis_url: Option, -} - -impl Default for AcmeConfig { - fn default() -> Self { - Self { - enabled: default_acme_enabled(), - port: default_acme_port(), - email: None, - storage_path: default_acme_storage_path(), - storage_type: None, - development: false, - redis_url: None, - } - } -} - -fn default_acme_enabled() -> bool { false } -fn default_acme_port() -> u16 { 9180 } -fn default_acme_storage_path() -> String { "/tmp/synapse-acme".to_string() } - -impl PingoraConfig { - /// Convert PingoraConfig to AppConfig for compatibility with old proxy system - pub fn to_app_config(&self) -> crate::utils::structs::AppConfig { - let mut app_config = crate::utils::structs::AppConfig::default(); - app_config.proxy_address_http = self.proxy_address_http.clone(); - app_config.proxy_address_tls = self.proxy_address_tls.clone(); - app_config.proxy_certificates = self.proxy_certificates.clone(); - app_config.proxy_tls_grade = Some(self.proxy_tls_grade.clone()); - app_config.default_certificate = self.default_certificate.clone(); - app_config.upstreams_conf = self.upstreams_conf.clone(); - app_config.config_address = self.config_address.clone(); - app_config.config_api_enabled = self.config_api_enabled; - app_config.master_key = self.master_key.clone(); - app_config.healthcheck_method = self.healthcheck_method.clone(); - app_config.healthcheck_interval = self.healthcheck_interval; - app_config.proxy_protocol_enabled = self.proxy_protocol.enabled; - - // Parse config_address to local_server - if let Some((ip, port_str)) = self.config_address.split_once(':') { - if let Ok(port) = 
port_str.parse::() { - app_config.local_server = Some((ip.to_string(), port)); - } - } - - // Parse proxy_address_tls to proxy_port_tls - if let Some(ref tls_addr) = self.proxy_address_tls { - if let Some((_, port_str)) = tls_addr.split_once(':') { - if let Ok(port) = port_str.parse::() { - app_config.proxy_port_tls = Some(port); - } - } - } - - app_config - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::env; - use serial_test::serial; - - #[test] - fn test_redis_ssl_config_deserialize() { - let yaml = r#" -ca_cert_path: "/path/to/ca.crt" -client_cert_path: "/path/to/client.crt" -client_key_path: "/path/to/client.key" -insecure: true -"#; - let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(config.ca_cert_path, Some("/path/to/ca.crt".to_string())); - assert_eq!(config.client_cert_path, Some("/path/to/client.crt".to_string())); - assert_eq!(config.client_key_path, Some("/path/to/client.key".to_string())); - assert!(config.insecure); - } - - #[test] - fn test_redis_ssl_config_deserialize_minimal() { - let yaml = r#" -insecure: false -"#; - let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(config.ca_cert_path, None); - assert_eq!(config.client_cert_path, None); - assert_eq!(config.client_key_path, None); - assert!(!config.insecure); - } - - #[test] - fn test_redis_config_with_ssl() { - let yaml = r#" -url: "rediss://localhost:6379" -prefix: "test:prefix" -ssl: - ca_cert_path: "/path/to/ca.crt" - insecure: false -"#; - let config: RedisConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(config.url, "rediss://localhost:6379"); - assert_eq!(config.prefix, "test:prefix"); - assert!(config.ssl.is_some()); - let ssl = config.ssl.unwrap(); - assert_eq!(ssl.ca_cert_path, Some("/path/to/ca.crt".to_string())); - assert!(!ssl.insecure); - } - - #[test] - fn test_redis_config_without_ssl() { - let yaml = r#" -url: "redis://localhost:6379" -prefix: "test:prefix" -"#; - let config: RedisConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(config.url, "redis://localhost:6379"); - assert_eq!(config.prefix, "test:prefix"); - assert!(config.ssl.is_none()); - } - - #[test] - fn test_redis_config_default() { - let config = RedisConfig::default(); - assert_eq!(config.url, ""); - assert_eq!(config.prefix, ""); - assert!(config.ssl.is_none()); - } - - // Helper function to clean up SSL environment variables for test isolation - fn cleanup_redis_ssl_env_vars() { - unsafe { - env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); - env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); - env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); - env::remove_var("AX_REDIS_SSL_INSECURE"); - } - } - - #[test] - #[serial] - fn test_apply_env_overrides_redis_ssl_ca_cert() { - cleanup_redis_ssl_env_vars(); - - let mut config = Config::default(); - unsafe { - env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt"); - } - - config.apply_env_overrides(); - - assert!(config.redis.ssl.is_some()); - assert_eq!(config.redis.ssl.as_ref().unwrap().ca_cert_path, Some("/test/ca.crt".to_string())); - - unsafe { - env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); - } - } - - #[test] - #[serial] - fn test_apply_env_overrides_redis_ssl_client_cert() { - cleanup_redis_ssl_env_vars(); - - let mut config = Config::default(); - unsafe { - env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt"); - env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key"); - } - - config.apply_env_overrides(); - - assert!(config.redis.ssl.is_some()); - let ssl = 
config.redis.ssl.as_ref().unwrap(); - assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); - assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); - - unsafe { - env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); - env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); - } - } - - #[test] - #[serial] - fn test_apply_env_overrides_redis_ssl_insecure() { - cleanup_redis_ssl_env_vars(); - - let mut config = Config::default(); - unsafe { - env::set_var("AX_REDIS_SSL_INSECURE", "true"); - } - - config.apply_env_overrides(); - - assert!(config.redis.ssl.is_some()); - assert!(config.redis.ssl.as_ref().unwrap().insecure); - - unsafe { - env::remove_var("AX_REDIS_SSL_INSECURE"); - } - } - - #[test] - #[serial] - fn test_apply_env_overrides_redis_ssl_insecure_false() { - cleanup_redis_ssl_env_vars(); - - let mut config = Config::default(); - unsafe { - env::set_var("AX_REDIS_SSL_INSECURE", "false"); - } - - config.apply_env_overrides(); - - assert!(config.redis.ssl.is_some()); - assert!(!config.redis.ssl.as_ref().unwrap().insecure); - - unsafe { - env::remove_var("AX_REDIS_SSL_INSECURE"); - } - } - - #[test] - #[serial] - fn test_apply_env_overrides_redis_ssl_combined() { - cleanup_redis_ssl_env_vars(); - let mut config = Config::default(); - unsafe { - env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt"); - env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt"); - env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key"); - env::set_var("AX_REDIS_SSL_INSECURE", "true"); - } - - config.apply_env_overrides(); - - assert!(config.redis.ssl.is_some()); - let ssl = config.redis.ssl.as_ref().unwrap(); - assert_eq!(ssl.ca_cert_path, Some("/test/ca.crt".to_string())); - assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); - assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); - assert!(ssl.insecure); - - unsafe { - env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); - env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); - env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); - env::remove_var("AX_REDIS_SSL_INSECURE"); - } - } - - #[test] - fn test_redis_ssl_config_default_insecure() { - let config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: false, - }; - assert!(!config.insecure); - } -} +use std::{path::PathBuf, env}; + +use anyhow::Result; +use clap::Parser; +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; + +use crate::waf::actions::captcha::CaptchaProvider; + +/// TLS operating mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] +#[serde(rename_all = "lowercase")] +pub enum TlsMode { + /// TLS is disabled + Disabled, +} + +/// Application operating mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] +#[serde(rename_all = "lowercase")] +pub enum AppMode { + /// Agent mode: Only access rules and monitoring (no proxy) + Agent, + /// Proxy mode: Full reverse proxy functionality + Proxy, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + #[serde(default = "default_mode")] + pub mode: String, + + /// Number of Tokio worker threads. In agent mode, defaults to 0 (single-threaded). + /// Set to a positive number to force multi-threaded runtime. + /// In proxy mode, defaults to number of CPUs. 
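// A sketch of how `worker_threads` could drive runtime construction per the
// doc comment above (illustrative assumption only; the actual wiring lives in
// main.rs, and `build_runtime` is a hypothetical helper name):
fn build_runtime(
    mode: &str,
    worker_threads: Option<usize>,
) -> std::io::Result<tokio::runtime::Runtime> {
    // Agent mode defaults to 0 (single-threaded); proxy mode defaults to CPU count.
    let default_threads = if mode == "proxy" {
        std::thread::available_parallelism().map(usize::from).unwrap_or(1)
    } else {
        0
    };
    match worker_threads.unwrap_or(default_threads) {
        // Current-thread runtime for the single-threaded agent case.
        0 => tokio::runtime::Builder::new_current_thread().enable_all().build(),
        // Multi-threaded runtime with an explicit worker count otherwise.
        n => tokio::runtime::Builder::new_multi_thread()
            .worker_threads(n)
            .enable_all()
            .build(),
    }
}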
+ #[serde(default)] + pub worker_threads: Option, + + // Global server options (moved from server section) + #[serde(default)] + pub redis: RedisConfig, + #[serde(default)] + pub network: NetworkConfig, + #[serde(default, alias = "arxignis")] + pub platform: Gen0SecConfig, + #[serde(default)] + pub geoip: GeoipConfig, + #[serde(default)] + pub content_scanning: ContentScanningCliConfig, + #[serde(default)] + pub logging: LoggingConfig, + #[serde(default)] + pub bpf_stats: BpfStatsConfig, + #[serde(default)] + pub tcp_fingerprint: TcpFingerprintConfig, + #[serde(default)] + pub daemon: DaemonConfig, + #[serde(default)] + pub pingora: PingoraConfig, + #[serde(default)] + pub acme: AcmeConfig, +} + +fn default_mode() -> String { "agent".to_string() } + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProxyProtocolConfig { + #[serde(default = "default_proxy_protocol_enabled")] + pub enabled: bool, + #[serde(default = "default_proxy_protocol_timeout")] + pub timeout_ms: u64, +} + +fn default_proxy_protocol_enabled() -> bool { false } +fn default_proxy_protocol_timeout() -> u64 { 1000 } + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HealthCheckConfig { + #[serde(default = "default_health_check_enabled")] + pub enabled: bool, + #[serde(default = "default_health_check_endpoint")] + pub endpoint: String, + #[serde(default = "default_health_check_port")] + pub port: String, + #[serde(default = "default_health_check_methods")] + pub methods: Vec, + #[serde(default = "default_health_check_allowed_cidrs")] + pub allowed_cidrs: Vec, +} + +fn default_health_check_enabled() -> bool { true } +fn default_health_check_endpoint() -> String { "/health".to_string() } +fn default_health_check_port() -> String { "0.0.0.0:8080".to_string() } +fn default_health_check_methods() -> Vec { vec!["GET".to_string(), "HEAD".to_string()] } +fn default_health_check_allowed_cidrs() -> Vec { vec![] } + + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RedisConfig { + #[serde(default)] + pub url: String, + #[serde(default)] + pub prefix: String, + /// Redis SSL/TLS configuration + #[serde(default)] + pub ssl: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RedisSslConfig { + /// Path to CA certificate file (PEM format) + pub ca_cert_path: Option, + /// Path to client certificate file (PEM format, optional) + pub client_cert_path: Option, + /// Path to client private key file (PEM format, optional) + pub client_key_path: Option, + /// Skip certificate verification (for testing with self-signed certs) + #[serde(default)] + pub insecure: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct NetworkConfig { + #[serde(default)] + pub iface: String, + #[serde(default)] + pub ifaces: Vec, + #[serde(default)] + pub disable_xdp: bool, + /// IP version support mode: "ipv4", "ipv6", or "both" (default: "both") + /// Note: XDP requires IPv6 to be enabled at kernel level for attachment, + /// even in IPv4-only mode. Set to "ipv4" to skip XDP if IPv6 cannot be enabled. 
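// The string -> FirewallMode mapping is duplicated between `merge_with_args`
// and the FIREWALL_MODE env override; a shared helper (hypothetical, sketched
// from the match arms in this file) would keep the accepted values in one place:
fn parse_firewall_mode(s: &str) -> crate::firewall::FirewallMode {
    use crate::firewall::FirewallMode;
    match s.to_lowercase().as_str() {
        "auto" => FirewallMode::Auto,
        "xdp" => FirewallMode::Xdp,
        "nftables" => FirewallMode::Nftables,
        "iptables" => FirewallMode::Iptables,
        "none" => FirewallMode::None,
        other => {
            log::warn!("Unknown firewall mode '{}', using auto", other);
            FirewallMode::Auto
        }
    }
}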
+ #[serde(default = "default_ip_version")] + pub ip_version: String, + /// Firewall backend mode: auto, xdp, nftables, iptables, none + #[serde(default)] + pub firewall_mode: crate::firewall::FirewallMode, +} + +fn default_ip_version() -> String { + "both".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Gen0SecConfig { + #[serde(default)] + pub api_key: String, + #[serde(default)] + pub workspace_id: String, + #[serde(default = "default_base_url")] + pub base_url: String, + /// Threat MMDB database configuration + #[serde(default)] + pub threat: GeoipDatabaseConfig, + #[serde(default = "default_log_sending_enabled")] + pub log_sending_enabled: bool, + #[serde(default = "default_include_response_body")] + pub include_response_body: bool, + #[serde(default = "default_max_body_size")] + pub max_body_size: usize, + #[serde(default)] + pub captcha: CaptchaConfig, +} + +fn default_base_url() -> String { + "https://api.gen0sec.com/v1".to_string() +} + + + +fn default_log_sending_enabled() -> bool { + true +} + +fn default_include_response_body() -> bool { + true +} + +fn default_max_body_size() -> usize { + 1024 * 1024 // 1MB +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct GeoipDatabaseConfig { + /// URL to download GeoIP MMDB file + #[serde(default)] + pub url: String, + /// Optional local path to store/read the GeoIP MMDB file + #[serde(default)] + pub path: Option, + /// Optional custom headers to add to download requests + #[serde(default)] + pub headers: Option>, + /// Refresh interval in seconds (optional, overrides parent refresh_secs if set) + #[serde(default)] + pub refresh_secs: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct GeoipConfig { + /// Country database configuration + #[serde(default)] + pub country: GeoipDatabaseConfig, + /// ASN database configuration + #[serde(default)] + pub asn: GeoipDatabaseConfig, + /// City database configuration + #[serde(default)] + pub city: GeoipDatabaseConfig, + /// How often to refresh the GeoIP MMDB from the remote URL (seconds). 
Default: 28800 (8 hours) + #[serde(default = "default_geoip_refresh_secs")] + pub refresh_secs: u64, +} + +fn default_geoip_refresh_secs() -> u64 { + 28800 // 8 hours +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ContentScanningCliConfig { + #[serde(default = "default_scanning_enabled")] + pub enabled: bool, + #[serde(default = "default_clamav_server")] + pub clamav_server: String, + #[serde(default = "default_max_file_size")] + pub max_file_size: usize, + #[serde(default)] + pub scan_content_types: Vec, + #[serde(default)] + pub skip_extensions: Vec, + #[serde(default = "default_scan_expression")] + pub scan_expression: String, +} + +fn default_scanning_enabled() -> bool { false } +fn default_clamav_server() -> String { "localhost:3310".to_string() } +fn default_max_file_size() -> usize { 10 * 1024 * 1024 } +fn default_scan_expression() -> String { "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() } + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LoggingConfig { + #[serde(default)] + pub level: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CaptchaConfig { + #[serde(default)] + pub site_key: Option, + #[serde(default)] + pub secret_key: Option, + #[serde(default)] + pub jwt_secret: Option, + #[serde(default)] + pub provider: String, + #[serde(default)] + pub token_ttl: u64, + #[serde(default)] + pub cache_ttl: u64, +} + +impl Config { + pub fn load_from_file(path: &PathBuf) -> Result { + let content = std::fs::read_to_string(path)?; + let config: Config = serde_yaml::from_str(&content)?; + Ok(config) + } + + pub fn default() -> Self { + Self { + mode: "agent".to_string(), + worker_threads: None, + redis: RedisConfig { + url: "redis://127.0.0.1/0".to_string(), + prefix: "ax:synapse".to_string(), + ssl: None, + }, + network: NetworkConfig { + iface: "eth0".to_string(), + ifaces: vec![], + disable_xdp: false, + ip_version: "both".to_string(), + firewall_mode: crate::firewall::FirewallMode::default(), + }, + platform: Gen0SecConfig { + api_key: "".to_string(), + workspace_id: "".to_string(), + base_url: "https://api.gen0sec.com/v1".to_string(), + threat: GeoipDatabaseConfig::default(), + log_sending_enabled: true, + include_response_body: true, + max_body_size: 1024 * 1024, // 1MB + captcha: CaptchaConfig { + site_key: None, + secret_key: None, + jwt_secret: None, + provider: "hcaptcha".to_string(), + token_ttl: 7200, + cache_ttl: 300, + }, + }, + geoip: GeoipConfig { + country: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + asn: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + city: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + refresh_secs: 28800, + }, + content_scanning: ContentScanningCliConfig { + enabled: false, + clamav_server: "localhost:3310".to_string(), + max_file_size: 10 * 1024 * 1024, + scan_content_types: vec![ + "text/html".to_string(), + "application/x-www-form-urlencoded".to_string(), + "multipart/form-data".to_string(), + "application/json".to_string(), + "text/plain".to_string(), + ], + skip_extensions: vec![], + scan_expression: default_scan_expression(), + }, + logging: LoggingConfig { + level: "info".to_string(), + }, + bpf_stats: BpfStatsConfig::default(), + tcp_fingerprint: TcpFingerprintConfig::default(), + daemon: DaemonConfig::default(), + pingora: PingoraConfig::default(), + 
acme: AcmeConfig::default(), + } + } + + pub fn merge_with_args(&mut self, args: &Args) { + // Override config values with command line arguments if provided + + if !args.ifaces.is_empty() { + self.network.ifaces = args.ifaces.clone(); + } + if args.disable_xdp { + self.network.disable_xdp = true; + } + if let Some(ref mode) = args.firewall_mode { + self.network.firewall_mode = match mode.to_lowercase().as_str() { + "auto" => crate::firewall::FirewallMode::Auto, + "xdp" => crate::firewall::FirewallMode::Xdp, + "nftables" => crate::firewall::FirewallMode::Nftables, + "iptables" => crate::firewall::FirewallMode::Iptables, + "none" => crate::firewall::FirewallMode::None, + _ => { + log::warn!("Unknown firewall mode '{}', using auto", mode); + crate::firewall::FirewallMode::Auto + } + }; + } + if let Some(api_key) = &args.arxignis_api_key { + self.platform.api_key = api_key.clone(); + } + if !args.arxignis_workspace_id.is_empty() { + self.platform.workspace_id = args.arxignis_workspace_id.clone(); + } + if !args.arxignis_base_url.is_empty() && args.arxignis_base_url != "https://api.gen0sec.com/v1" { + self.platform.base_url = args.arxignis_base_url.clone(); + } + if let Some(log_sending_enabled) = args.arxignis_log_sending_enabled { + self.platform.log_sending_enabled = log_sending_enabled; + } + self.platform.include_response_body = args.arxignis_include_response_body; + self.platform.max_body_size = args.arxignis_max_body_size; + if args.captcha_site_key.is_some() { + self.platform.captcha.site_key = args.captcha_site_key.clone(); + } + if args.captcha_secret_key.is_some() { + self.platform.captcha.secret_key = args.captcha_secret_key.clone(); + } + if args.captcha_jwt_secret.is_some() { + self.platform.captcha.jwt_secret = args.captcha_jwt_secret.clone(); + } + if let Some(provider) = &args.captcha_provider { + self.platform.captcha.provider = format!("{:?}", provider).to_lowercase(); + } + + // Proxy protocol configuration overrides + if args.proxy_protocol_enabled { + self.pingora.proxy_protocol.enabled = true; + } + if args.proxy_protocol_timeout != 1000 { + self.pingora.proxy_protocol.timeout_ms = args.proxy_protocol_timeout; + } + + // Daemon configuration overrides + if args.daemon { + self.daemon.enabled = true; + } + if args.daemon_pid_file != "/var/run/synapse.pid" { + self.daemon.pid_file = args.daemon_pid_file.clone(); + } + if args.daemon_working_dir != "/" { + self.daemon.working_directory = args.daemon_working_dir.clone(); + } + if args.daemon_stdout != "/var/log/synapse.out" { + self.daemon.stdout = args.daemon_stdout.clone(); + } + if args.daemon_stderr != "/var/log/synapse.err" { + self.daemon.stderr = args.daemon_stderr.clone(); + } + if args.daemon_user.is_some() { + self.daemon.user = args.daemon_user.clone(); + } + if args.daemon_group.is_some() { + self.daemon.group = args.daemon_group.clone(); + } + + // Redis configuration overrides + if !args.redis_url.is_empty() && args.redis_url != "redis://127.0.0.1/0" { + self.redis.url = args.redis_url.clone(); + } + if !args.redis_prefix.is_empty() && args.redis_prefix != "ax:synapse" { + self.redis.prefix = args.redis_prefix.clone(); + } + } + + pub fn validate_required_fields(&mut self, args: &Args) -> Result<()> { + // Check if arxignis API key is provided - only warn if not provided + // (to support old config format that doesn't have this field) + if args.arxignis_api_key.is_none() && self.platform.api_key.is_empty() { + log::warn!("Gen0Sec API key not provided. 
Some features may not work."); + } + + Ok(()) + } + + pub fn load_from_args(args: &Args) -> Result { + let mut config = if let Some(config_path) = &args.config { + Self::load_from_file(config_path)? + } else { + Self::default() + }; + + config.merge_with_args(args); + config.apply_env_overrides(); + config.validate_required_fields(args)?; + Ok(config) + } + + pub fn apply_env_overrides(&mut self) { + // Helper function to get env var, preferring non-prefixed version + // AX_ prefix is deprecated and will be removed in a future version + fn get_env(name: &str) -> Option { + // Prefer non-prefixed version + if let Ok(val) = env::var(name) { + return Some(val); + } + // Fallback to AX_ prefix with deprecation warning + let ax_name = format!("AX_{}", name); + if let Ok(val) = env::var(&ax_name) { + log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); + return Some(val); + } + None + } + + // Helper for arxignis vars: name > AX_name > AX_ARXIGNIS_name > ARXIGNIS_name (backward compat) + // ARXIGNIS prefix is deprecated and will be removed in a future version + fn get_env_arxignis(name: &str) -> Option { + // Prefer non-prefixed version + if let Ok(val) = env::var(name) { + return Some(val); + } + // Fallback to AX_ prefix with deprecation warning + let ax_name = format!("AX_{}", name); + if let Ok(val) = env::var(&ax_name) { + log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); + return Some(val); + } + // Fallback to AX_ARXIGNIS_ prefix with deprecation warning + let ax_arxignis_name = format!("AX_ARXIGNIS_{}", name); + if let Ok(val) = env::var(&ax_arxignis_name) { + log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", ax_arxignis_name, name); + return Some(val); + } + // Fallback to ARXIGNIS_ prefix with deprecation warning + let arxignis_name = format!("ARXIGNIS_{}", name); + if let Ok(val) = env::var(&arxignis_name) { + log::warn!("Environment variable '{}' is deprecated, use '{}' instead. 
+        // Mode override
+        if let Some(val) = get_env("MODE") {
+            self.mode = val;
+        }
+
+        // Redis configuration overrides
+        if let Some(val) = get_env("REDIS_URL") {
+            self.redis.url = val;
+        }
+        if let Some(val) = get_env("REDIS_PREFIX") {
+            self.redis.prefix = val;
+        }
+
+        // Redis SSL configuration overrides
+        // Read all SSL environment variables once
+        let ca_cert_path = get_env("REDIS_SSL_CA_CERT_PATH");
+        let client_cert_path = get_env("REDIS_SSL_CLIENT_CERT_PATH");
+        let client_key_path = get_env("REDIS_SSL_CLIENT_KEY_PATH");
+        let insecure_val = get_env("REDIS_SSL_INSECURE");
+
+        // If any SSL env var is set, ensure the SSL config exists
+        if ca_cert_path.is_some()
+            || client_cert_path.is_some()
+            || client_key_path.is_some()
+            || insecure_val.is_some() {
+
+            // Create the SSL config if it doesn't exist
+            if self.redis.ssl.is_none() {
+                // Parse the insecure value if provided, defaulting to false
+                let insecure_default = insecure_val
+                    .as_ref()
+                    .and_then(|v| v.parse::<bool>().ok())
+                    .unwrap_or(false);
+
+                self.redis.ssl = Some(RedisSslConfig {
+                    ca_cert_path: None,
+                    client_cert_path: None,
+                    client_key_path: None,
+                    insecure: insecure_default,
+                });
+            }
+
+            // Update the SSL config with values from environment variables
+            let ssl = self.redis.ssl.as_mut().expect("SSL config should exist here");
+            if let Some(val) = ca_cert_path {
+                ssl.ca_cert_path = Some(val);
+            }
+            if let Some(val) = client_cert_path {
+                ssl.client_cert_path = Some(val);
+            }
+            if let Some(val) = client_key_path {
+                ssl.client_key_path = Some(val);
+            }
+            if let Some(val) = insecure_val {
+                if let Ok(insecure) = val.parse::<bool>() {
+                    ssl.insecure = insecure;
+                }
+            }
+        }
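+        // Illustrative example (hypothetical paths): setting
+        //   REDIS_SSL_CA_CERT_PATH=/etc/ssl/redis-ca.pem REDIS_SSL_INSECURE=true
+        // yields redis.ssl = Some(RedisSslConfig {
+        //     ca_cert_path: Some("/etc/ssl/redis-ca.pem".to_string()),
+        //     client_cert_path: None, client_key_path: None, insecure: true });
+        // an unparsable REDIS_SSL_INSECURE value leaves `insecure` unchanged.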
+        // Network configuration overrides
+        if let Some(val) = get_env("NETWORK_IFACE") {
+            self.network.iface = val;
+        }
+        if let Some(val) = get_env("NETWORK_IFACES") {
+            self.network.ifaces = val.split(',').map(|s| s.trim().to_string()).collect();
+        }
+        if let Some(val) = get_env("NETWORK_DISABLE_XDP") {
+            self.network.disable_xdp = val.parse().unwrap_or(false);
+        }
+        if let Some(val) = get_env("FIREWALL_MODE").or_else(|| get_env("NETWORK_FIREWALL_MODE")) {
+            self.network.firewall_mode = match val.to_lowercase().as_str() {
+                "auto" => crate::firewall::FirewallMode::Auto,
+                "xdp" => crate::firewall::FirewallMode::Xdp,
+                "nftables" => crate::firewall::FirewallMode::Nftables,
+                "iptables" => crate::firewall::FirewallMode::Iptables,
+                "none" => crate::firewall::FirewallMode::None,
+                _ => crate::firewall::FirewallMode::Auto,
+            };
+        }
+        if let Some(val) = get_env("NETWORK_IP_VERSION") {
+            // Validate the ip_version value
+            match val.as_str() {
+                "ipv4" | "ipv6" | "both" => {
+                    self.network.ip_version = val;
+                }
+                _ => {
+                    log::warn!("Invalid NETWORK_IP_VERSION value '{}', using default 'both'. Valid values: ipv4, ipv6, both", val);
+                }
+            }
+        }
+
+        // Gen0Sec configuration overrides
+        // Supports: AX_API_KEY, API_KEY, AX_ARXIGNIS_API_KEY, ARXIGNIS_API_KEY (backward compat)
+        if let Some(val) = get_env_arxignis("API_KEY") {
+            self.platform.api_key = val;
+        }
+        if let Some(val) = get_env_arxignis("WORKSPACE_ID") {
+            self.platform.workspace_id = val;
+        }
+        if let Some(val) = get_env_arxignis("BASE_URL") {
+            self.platform.base_url = val;
+        }
+        if let Some(val) = get_env_arxignis("LOG_SENDING_ENABLED") {
+            if let Ok(parsed) = val.parse::<bool>() {
+                self.platform.log_sending_enabled = parsed;
+            }
+        }
+        if let Some(val) = get_env_arxignis("INCLUDE_RESPONSE_BODY") {
+            self.platform.include_response_body = val.parse().unwrap_or(true);
+        }
+        if let Some(val) = get_env_arxignis("MAX_BODY_SIZE") {
+            self.platform.max_body_size = val.parse().unwrap_or(1024 * 1024);
+        }
+
+        // Logging configuration overrides
+        if let Some(val) = get_env("LOGGING_LEVEL") {
+            self.logging.level = val;
+        }
+
+        // Content scanning overrides
+        if let Some(val) = get_env("CONTENT_SCANNING_ENABLED") {
+            self.content_scanning.enabled = val.parse().unwrap_or(false);
+        }
+        if let Some(val) = get_env("CLAMAV_SERVER") {
+            self.content_scanning.clamav_server = val;
+        }
+        if let Some(val) = get_env("CONTENT_MAX_FILE_SIZE") {
+            self.content_scanning.max_file_size = val.parse().unwrap_or(10 * 1024 * 1024);
+        }
+        if let Some(val) = get_env("CONTENT_SCAN_CONTENT_TYPES") {
+            self.content_scanning.scan_content_types = val.split(',').map(|s| s.trim().to_string()).collect();
+        }
+        if let Some(val) = get_env("CONTENT_SKIP_EXTENSIONS") {
+            self.content_scanning.skip_extensions = val.split(',').map(|s| s.trim().to_string()).collect();
+        }
+        if let Some(val) = get_env("CONTENT_SCAN_EXPRESSION") {
+            self.content_scanning.scan_expression = val;
+        }
+
+        // Captcha configuration overrides
+        if let Some(val) = get_env("CAPTCHA_SITE_KEY") {
+            self.platform.captcha.site_key = Some(val);
+        }
+        if let Some(val) = get_env("CAPTCHA_SECRET_KEY") {
+            self.platform.captcha.secret_key = Some(val);
+        }
+        if let Some(val) = get_env("CAPTCHA_JWT_SECRET") {
+            self.platform.captcha.jwt_secret = Some(val);
+        }
+        if let Some(val) = get_env("CAPTCHA_PROVIDER") {
+            self.platform.captcha.provider = val;
+        }
+        if let Some(val) = get_env("CAPTCHA_TOKEN_TTL") {
+            self.platform.captcha.token_ttl = val.parse().unwrap_or(7200);
+        }
+        if let Some(val) = get_env("CAPTCHA_CACHE_TTL") {
+            self.platform.captcha.cache_ttl = val.parse().unwrap_or(300);
+        }
+
+        // Proxy protocol configuration overrides
+        if let Some(val) = get_env("PROXY_PROTOCOL_ENABLED") {
+            self.pingora.proxy_protocol.enabled = val.parse().unwrap_or(false);
+        }
+        if let Some(val) = get_env("PROXY_PROTOCOL_TIMEOUT") {
+            self.pingora.proxy_protocol.timeout_ms = val.parse().unwrap_or(1000);
+        }
+
+        // Daemon configuration overrides
+        if let Some(val) = get_env("DAEMON_ENABLED") {
+            self.daemon.enabled = val.parse().unwrap_or(false);
+        }
+        if let Some(val) = get_env("DAEMON_PID_FILE") {
+            self.daemon.pid_file = val;
+        }
+        if let Some(val) = get_env("DAEMON_WORKING_DIRECTORY") {
+            self.daemon.working_directory = val;
+        }
+        if let Some(val) = get_env("DAEMON_STDOUT") {
+            self.daemon.stdout = val;
+        }
+        if let Some(val) = get_env("DAEMON_STDERR") {
+            self.daemon.stderr = val;
+        }
+        if let Some(val) = get_env("DAEMON_USER") {
+            self.daemon.user = Some(val);
+        }
+        if let Some(val) = get_env("DAEMON_GROUP") {
+            self.daemon.group = Some(val);
+        }
+        if let Some(val) = get_env("DAEMON_CHOWN_PID_FILE") {
+            self.daemon.chown_pid_file = val.parse().unwrap_or(true);
+        }
+    }
+}
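+// Illustrative load sequence (hypothetical main, not part of this file):
+//     let args = Args::parse();
+//     let config = Config::load_from_args(&args)?;
+// With `--firewall-mode nftables` on the CLI and FIREWALL_MODE=iptables in the
+// environment, the effective mode is Iptables, because apply_env_overrides()
+// is applied after merge_with_args().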
+#[derive(Parser, Debug, Clone)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// Path to configuration file (YAML format)
+    #[arg(long, short = 'c')]
+    pub config: Option<PathBuf>,
+
+    /// Path to security rules configuration file (access rules and WAF rules).
+    /// Used as a fallback when no Gen0Sec API key is provided
+    #[arg(long, default_value = "security_rules.yaml")]
+    pub security_rules_config: PathBuf,
+
+    /// Clear a specific certificate from local filesystem and Redis
+    #[arg(long)]
+    pub clear_certificate: Option<String>,
+
+    /// Redis connection URL for ACME cache storage.
+    #[arg(long, default_value = "redis://127.0.0.1/0")]
+    pub redis_url: String,
+
+    /// Namespace prefix for Redis ACME cache entries.
+    #[arg(long, default_value = "ax:synapse")]
+    pub redis_prefix: String,
+
+    /// The network interface to attach the XDP program to.
+    #[arg(short, long, default_value = "eth0")]
+    pub iface: String,
+
+    /// Additional network interfaces for XDP attach (comma-separated). If set, overrides --iface.
+    #[arg(long, value_delimiter = ',', num_args = 0..)]
+    pub ifaces: Vec<String>,
+
+    /// API key for the Gen0Sec platform
+    #[arg(long)]
+    pub arxignis_api_key: Option<String>,
+
+    /// Base URL for the Gen0Sec API.
+    #[arg(long, default_value = "https://api.gen0sec.com/v1")]
+    pub arxignis_base_url: String,
+
+    /// Workspace ID for agent identity
+    #[arg(long, default_value = "")]
+    pub arxignis_workspace_id: String,
+
+    /// Enable sending access logs to the arxignis server
+    #[arg(long)]
+    pub arxignis_log_sending_enabled: Option<bool>,
+
+    /// Include response body in access logs
+    #[arg(long, default_value_t = true)]
+    pub arxignis_include_response_body: bool,
+
+    /// Maximum size for request/response bodies in access logs (bytes)
+    #[arg(long, default_value = "1048576")]
+    pub arxignis_max_body_size: usize,
+
+    /// Log level (error, warn, info, debug, trace)
+    #[arg(long, value_enum, default_value_t = LogLevel::Info)]
+    pub log_level: LogLevel,
+
+    /// Disable XDP packet filtering (run without BPF/XDP)
+    #[arg(long, default_value_t = false)]
+    pub disable_xdp: bool,
+
+    /// Firewall backend mode (auto, xdp, nftables, iptables, none)
+    #[arg(long)]
+    pub firewall_mode: Option<String>,
+
+    /// Captcha site key for security verification
+    #[arg(long)]
+    pub captcha_site_key: Option<String>,
+
+    /// Captcha secret key for security verification
+    #[arg(long)]
+    pub captcha_secret_key: Option<String>,
+
+    /// JWT secret key for captcha token signing
+    #[arg(long)]
+    pub captcha_jwt_secret: Option<String>,
+
+    /// Captcha provider (hcaptcha, recaptcha, turnstile)
+    #[arg(long, value_enum)]
+    pub captcha_provider: Option,
+
+    /// Captcha token TTL in seconds
+    #[arg(long, default_value = "7200")]
+    pub captcha_token_ttl: u64,
+
+    /// Captcha validation cache TTL in seconds
+    #[arg(long, default_value = "300")]
+    pub captcha_cache_ttl: u64,
+
+    /// Enable PROXY protocol support for TCP connections
+    #[arg(long, default_value_t = false)]
+    pub proxy_protocol_enabled: bool,
+
+    /// PROXY protocol timeout in milliseconds
+    #[arg(long, default_value = "1000")]
+    pub proxy_protocol_timeout: u64,
+
+    /// Run as daemon in background
+    #[arg(long, short = 'd', default_value_t = false)]
+    pub daemon: bool,
+
+    /// PID file path for daemon mode
+    #[arg(long, default_value = "/var/run/synapse.pid")]
+    pub daemon_pid_file: String,
+
+    /// Working directory for daemon mode
+    #[arg(long, default_value = "/")]
+    pub daemon_working_dir: String,
+
+    /// Stdout log file for daemon mode
+    #[arg(long, default_value = "/var/log/synapse.out")]
+    pub daemon_stdout: String,
+
+    /// Stderr log file for daemon mode
+    #[arg(long, default_value = "/var/log/synapse.err")]
+    pub daemon_stderr: String,
+
+    /// User to run daemon as
+    #[arg(long)]
+    pub daemon_user: Option<String>,
+
+    /// Group to run daemon as
+    #[arg(long)]
+    pub daemon_group: Option<String>,
+}
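+// Illustrative invocation (hypothetical values): clap maps kebab-case flags
+// onto the fields above, e.g.
+//     synapse -c /etc/synapse/config.yaml --firewall-mode nftables \
+//         --ifaces eth0,eth1 --daemon --daemon-user synapse
+// Optional flags left unset keep the corresponding config values untouched in
+// merge_with_args().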
"/var/log/synapse.out")] + pub daemon_stdout: String, + + /// Stderr log file for daemon mode + #[arg(long, default_value = "/var/log/synapse.err")] + pub daemon_stderr: String, + + /// User to run daemon as + #[arg(long)] + pub daemon_user: Option, + + /// Group to run daemon as + #[arg(long)] + pub daemon_group: Option, +} + +#[derive(Copy, Clone, Debug, ValueEnum)] +pub enum LogLevel { + Error, + Warn, + Info, + Debug, + Trace, +} + +impl LogLevel { + pub fn to_level_filter(self) -> log::LevelFilter { + match self { + LogLevel::Error => log::LevelFilter::Error, + LogLevel::Warn => log::LevelFilter::Warn, + LogLevel::Info => log::LevelFilter::Info, + LogLevel::Debug => log::LevelFilter::Debug, + LogLevel::Trace => log::LevelFilter::Trace, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BpfStatsConfig { + #[serde(default = "default_bpf_stats_enabled")] + pub enabled: bool, + #[serde(default = "default_bpf_stats_log_interval")] + pub log_interval_secs: u64, + #[serde(default = "default_bpf_stats_enable_dropped_ip_events")] + pub enable_dropped_ip_events: bool, + #[serde(default = "default_bpf_stats_dropped_ip_events_interval")] + pub dropped_ip_events_interval_secs: u64, +} + +fn default_bpf_stats_enabled() -> bool { true } +fn default_bpf_stats_log_interval() -> u64 { 60 } +fn default_bpf_stats_enable_dropped_ip_events() -> bool { true } +fn default_bpf_stats_dropped_ip_events_interval() -> u64 { 30 } + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct TcpFingerprintConfig { + #[serde(default = "default_tcp_fingerprint_enabled")] + pub enabled: bool, + #[serde(default = "default_tcp_fingerprint_log_interval")] + pub log_interval_secs: u64, + #[serde(default = "default_tcp_fingerprint_enable_fingerprint_events")] + pub enable_fingerprint_events: bool, + #[serde(default = "default_tcp_fingerprint_events_interval")] + pub fingerprint_events_interval_secs: u64, + #[serde(default = "default_tcp_fingerprint_min_packet_count")] + pub min_packet_count: u32, + #[serde(default = "default_tcp_fingerprint_min_connection_duration")] + pub min_connection_duration_secs: u64, +} + +fn default_tcp_fingerprint_enabled() -> bool { true } +fn default_tcp_fingerprint_log_interval() -> u64 { 60 } +fn default_tcp_fingerprint_enable_fingerprint_events() -> bool { true } +fn default_tcp_fingerprint_events_interval() -> u64 { 30 } +fn default_tcp_fingerprint_min_packet_count() -> u32 { 3 } +fn default_tcp_fingerprint_min_connection_duration() -> u64 { 1 } + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct DaemonConfig { + #[serde(default = "default_daemon_enabled")] + pub enabled: bool, + #[serde(default = "default_daemon_pid_file")] + pub pid_file: String, + #[serde(default = "default_daemon_working_directory")] + pub working_directory: String, + #[serde(default = "default_daemon_stdout")] + pub stdout: String, + #[serde(default = "default_daemon_stderr")] + pub stderr: String, + pub user: Option, + pub group: Option, + #[serde(default = "default_daemon_chown_pid_file")] + pub chown_pid_file: bool, +} + +fn default_daemon_enabled() -> bool { false } +fn default_daemon_pid_file() -> String { "/var/run/synapse.pid".to_string() } +fn default_daemon_working_directory() -> String { "/".to_string() } +fn default_daemon_stdout() -> String { "/var/log/synapse.out".to_string() } +fn default_daemon_stderr() -> String { "/var/log/synapse.err".to_string() } +fn default_daemon_chown_pid_file() -> bool { true } + +#[derive(Debug, Clone, Serialize, 
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct PingoraConfig {
+    #[serde(default)]
+    pub proxy_address_http: String,
+    #[serde(default)]
+    pub proxy_address_tls: Option<String>,
+    #[serde(default)]
+    pub proxy_certificates: Option<String>,
+    #[serde(default = "default_pingora_tls_grade")]
+    pub proxy_tls_grade: String,
+    #[serde(default)]
+    pub default_certificate: Option<String>,
+    #[serde(default)]
+    pub upstreams_conf: String,
+    #[serde(default)]
+    pub config_address: String,
+    #[serde(default = "default_pingora_config_api_enabled")]
+    pub config_api_enabled: bool,
+    #[serde(default)]
+    pub master_key: String,
+    #[serde(default = "default_pingora_log_level")]
+    pub log_level: String,
+    #[serde(default = "default_pingora_healthcheck_method")]
+    pub healthcheck_method: String,
+    #[serde(default = "default_pingora_healthcheck_interval")]
+    pub healthcheck_interval: u16,
+    #[serde(default)]
+    pub proxy_protocol: ProxyProtocolConfig,
+}
+
+fn default_pingora_tls_grade() -> String { "medium".to_string() }
+fn default_pingora_config_api_enabled() -> bool { true }
+fn default_pingora_log_level() -> String { "debug".to_string() }
+fn default_pingora_healthcheck_method() -> String { "HEAD".to_string() }
+fn default_pingora_healthcheck_interval() -> u16 { 2 }
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AcmeConfig {
+    /// Enable embedded ACME server
+    #[serde(default = "default_acme_enabled")]
+    pub enabled: bool,
+    /// Port for ACME server (e.g., 9180)
+    #[serde(default = "default_acme_port")]
+    pub port: u16,
+    /// Email for ACME account
+    pub email: Option<String>,
+    /// Storage path for certificates
+    #[serde(default = "default_acme_storage_path")]
+    pub storage_path: String,
+    /// Storage type: "file" or "redis" (defaults to "file", or "redis" if redis_url is set)
+    pub storage_type: Option<String>,
+    /// Use development/staging ACME server
+    #[serde(default)]
+    pub development: bool,
+    /// Redis URL for storage (optional, uses global redis.url if not set)
+    pub redis_url: Option<String>,
+}
+
+impl Default for AcmeConfig {
+    fn default() -> Self {
+        Self {
+            enabled: default_acme_enabled(),
+            port: default_acme_port(),
+            email: None,
+            storage_path: default_acme_storage_path(),
+            storage_type: None,
+            development: false,
+            redis_url: None,
+        }
+    }
+}
+
+fn default_acme_enabled() -> bool { false }
+fn default_acme_port() -> u16 { 9180 }
+fn default_acme_storage_path() -> String { "/tmp/synapse-acme".to_string() }
+
+impl PingoraConfig {
+    /// Convert PingoraConfig to AppConfig for compatibility with the old proxy system
+    pub fn to_app_config(&self) -> crate::utils::structs::AppConfig {
+        let mut app_config = crate::utils::structs::AppConfig::default();
+        app_config.proxy_address_http = self.proxy_address_http.clone();
+        app_config.proxy_address_tls = self.proxy_address_tls.clone();
+        app_config.proxy_certificates = self.proxy_certificates.clone();
+        app_config.proxy_tls_grade = Some(self.proxy_tls_grade.clone());
+        app_config.default_certificate = self.default_certificate.clone();
+        app_config.upstreams_conf = self.upstreams_conf.clone();
+        app_config.config_address = self.config_address.clone();
+        app_config.config_api_enabled = self.config_api_enabled;
+        app_config.master_key = self.master_key.clone();
+        app_config.healthcheck_method = self.healthcheck_method.clone();
+        app_config.healthcheck_interval = self.healthcheck_interval;
+        app_config.proxy_protocol_enabled = self.proxy_protocol.enabled;
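+        // Illustrative mapping (hypothetical values): config_address
+        // "127.0.0.1:6188" yields local_server = Some(("127.0.0.1", 6188)),
+        // and proxy_address_tls "0.0.0.0:443" yields proxy_port_tls = Some(443),
+        // via the parsing below.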
+        // Parse config_address to local_server
+        if let Some((ip, port_str)) = self.config_address.split_once(':') {
+            if let Ok(port) = port_str.parse::<u16>() {
+                app_config.local_server = Some((ip.to_string(), port));
+            }
+        }
+
+        // Parse proxy_address_tls to proxy_port_tls
+        if let Some(ref tls_addr) = self.proxy_address_tls {
+            if let Some((_, port_str)) = tls_addr.split_once(':') {
+                if let Ok(port) = port_str.parse::<u16>() {
+                    app_config.proxy_port_tls = Some(port);
+                }
+            }
+        }
+
+        app_config
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::env;
+    use serial_test::serial;
+
+    #[test]
+    fn test_redis_ssl_config_deserialize() {
+        let yaml = r#"
+ca_cert_path: "/path/to/ca.crt"
+client_cert_path: "/path/to/client.crt"
+client_key_path: "/path/to/client.key"
+insecure: true
+"#;
+        let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap();
+        assert_eq!(config.ca_cert_path, Some("/path/to/ca.crt".to_string()));
+        assert_eq!(config.client_cert_path, Some("/path/to/client.crt".to_string()));
+        assert_eq!(config.client_key_path, Some("/path/to/client.key".to_string()));
+        assert!(config.insecure);
+    }
+
+    #[test]
+    fn test_redis_ssl_config_deserialize_minimal() {
+        let yaml = r#"
+insecure: false
+"#;
+        let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap();
+        assert_eq!(config.ca_cert_path, None);
+        assert_eq!(config.client_cert_path, None);
+        assert_eq!(config.client_key_path, None);
+        assert!(!config.insecure);
+    }
+
+    #[test]
+    fn test_redis_config_with_ssl() {
+        let yaml = r#"
+url: "rediss://localhost:6379"
+prefix: "test:prefix"
+ssl:
+  ca_cert_path: "/path/to/ca.crt"
+  insecure: false
+"#;
+        let config: RedisConfig = serde_yaml::from_str(yaml).unwrap();
+        assert_eq!(config.url, "rediss://localhost:6379");
+        assert_eq!(config.prefix, "test:prefix");
+        assert!(config.ssl.is_some());
+        let ssl = config.ssl.unwrap();
+        assert_eq!(ssl.ca_cert_path, Some("/path/to/ca.crt".to_string()));
+        assert!(!ssl.insecure);
+    }
+
+    #[test]
+    fn test_redis_config_without_ssl() {
+        let yaml = r#"
+url: "redis://localhost:6379"
+prefix: "test:prefix"
+"#;
+        let config: RedisConfig = serde_yaml::from_str(yaml).unwrap();
+        assert_eq!(config.url, "redis://localhost:6379");
+        assert_eq!(config.prefix, "test:prefix");
+        assert!(config.ssl.is_none());
+    }
+
+    #[test]
+    fn test_redis_config_default() {
+        let config = RedisConfig::default();
+        assert_eq!(config.url, "");
+        assert_eq!(config.prefix, "");
+        assert!(config.ssl.is_none());
+    }
+
+    // Helper function to clean up SSL environment variables for test isolation
+    fn cleanup_redis_ssl_env_vars() {
+        unsafe {
+            env::remove_var("AX_REDIS_SSL_CA_CERT_PATH");
+            env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH");
+            env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH");
+            env::remove_var("AX_REDIS_SSL_INSECURE");
+        }
+    }
+
+    #[test]
+    #[serial]
+    fn test_apply_env_overrides_redis_ssl_ca_cert() {
+        cleanup_redis_ssl_env_vars();
+
+        let mut config = Config::default();
+        unsafe {
+            env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt");
+        }
+
+        config.apply_env_overrides();
+
+        assert!(config.redis.ssl.is_some());
+        assert_eq!(config.redis.ssl.as_ref().unwrap().ca_cert_path, Some("/test/ca.crt".to_string()));
+
+        unsafe {
+            env::remove_var("AX_REDIS_SSL_CA_CERT_PATH");
+        }
+    }
+
+    #[test]
+    #[serial]
+    fn test_apply_env_overrides_redis_ssl_client_cert() {
+        cleanup_redis_ssl_env_vars();
+
+        let mut config = Config::default();
+        unsafe {
+            env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt");
+            env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key");
+        }
+
+        config.apply_env_overrides();
+
+        assert!(config.redis.ssl.is_some());
+        let ssl = 
config.redis.ssl.as_ref().unwrap(); + assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); + assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); + + unsafe { + env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_insecure() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_INSECURE", "true"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + assert!(config.redis.ssl.as_ref().unwrap().insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_insecure_false() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_INSECURE", "false"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + assert!(!config.redis.ssl.as_ref().unwrap().insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_combined() { + cleanup_redis_ssl_env_vars(); + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt"); + env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt"); + env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key"); + env::set_var("AX_REDIS_SSL_INSECURE", "true"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + let ssl = config.redis.ssl.as_ref().unwrap(); + assert_eq!(ssl.ca_cert_path, Some("/test/ca.crt".to_string())); + assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); + assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); + assert!(ssl.insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + fn test_redis_ssl_config_default_insecure() { + let config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + assert!(!config.insecure); + } +} diff --git a/src/content_scanning/mod.rs b/src/content_scanning/mod.rs index b607526..bd220c1 100644 --- a/src/content_scanning/mod.rs +++ b/src/content_scanning/mod.rs @@ -1,652 +1,652 @@ -use anyhow::{Result, anyhow}; -use clamav_tcp::scan; -use serde::{Deserialize, Serialize}; -use std::io::Cursor; -use std::sync::{Arc, RwLock, OnceLock}; -use std::collections::HashMap; -use std::net::SocketAddr; -use hyper::http::request::Parts; -use wirefilter::{ExecutionContext, Scheme, Filter}; -use bytes::Bytes; -use multer::Multipart; - -/// Content scanning configuration -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ContentScanningConfig { - /// Enable or disable content scanning - pub enabled: bool, - /// ClamAV server address (e.g., "localhost:3310") - pub clamav_server: String, - /// Maximum file size to scan in bytes (default: 10MB) - pub max_file_size: usize, - /// Content types to scan (empty means scan all) - pub scan_content_types: Vec, - /// Skip scanning for specific file extensions - pub skip_extensions: Vec, - /// Wirefilter expression to determine when to scan - #[serde(default = "default_scan_expression")] - pub scan_expression: String, -} - -fn 
default_scan_expression() -> String { - "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() -} - -impl Default for ContentScanningConfig { - fn default() -> Self { - Self { - enabled: false, - clamav_server: "localhost:3310".to_string(), - max_file_size: 10 * 1024 * 1024, // 10MB - scan_content_types: vec![ - "text/html".to_string(), - "application/x-www-form-urlencoded".to_string(), - "multipart/form-data".to_string(), - "application/json".to_string(), - "text/plain".to_string(), - ], - skip_extensions: vec![], - scan_expression: default_scan_expression(), - } - } -} - -/// Content scanning result -#[derive(Debug, Clone)] -pub struct ScanResult { - /// Whether malware was detected - pub malware_detected: bool, - /// Malware signature name if detected - pub signature: Option, - /// Error message if scanning failed - pub error: Option, -} - -/// Content scanner implementation -pub struct ContentScanner { - config: Arc>, - scheme: Arc, - filter: Arc>>, -} - -/// Extract boundary from Content-Type header for multipart content -pub fn extract_multipart_boundary(content_type: &str) -> Option { - // Content-Type format: multipart/form-data; boundary=----WebKitFormBoundary... - if !content_type.to_lowercase().contains("multipart/") { - return None; - } - - for part in content_type.split(';') { - let trimmed = part.trim(); - let lower = trimmed.to_lowercase(); - if lower.starts_with("boundary=") { - // Find the actual position of "boundary=" in the original string (case-insensitive) - if let Some(eq_pos) = trimmed.to_lowercase().find("boundary=") { - let boundary = trimmed[eq_pos + 9..].trim(); - // Remove quotes if present - let boundary = boundary.trim_matches('"').trim_matches('\''); - return Some(boundary.to_string()); - } - } - } - - None -} - -impl ContentScanner { - /// Create a new content scanner - pub fn new(config: ContentScanningConfig) -> Self { - let scheme = Self::create_scheme(); - let filter = Self::compile_filter(&scheme, &config.scan_expression); - - Self { - config: Arc::new(RwLock::new(config)), - scheme: Arc::new(scheme), - filter: Arc::new(RwLock::new(filter)), - } - } - - /// Create the wirefilter scheme for content scanning - fn create_scheme() -> Scheme { - let builder = wirefilter::Scheme! 
{ - http.request.method: Bytes, - http.request.path: Bytes, - http.request.content_type: Bytes, - http.request.content_length: Int, - }; - builder.build() - } - - /// Compile the scan expression filter - fn compile_filter(scheme: &Scheme, expression: &str) -> Option { - if expression.is_empty() { - return None; - } - - match scheme.parse(expression) { - Ok(ast) => Some(ast.compile()), - Err(e) => { - log::error!("Failed to compile content scanning expression '{}': {}", expression, e); - None - } - } - } - - /// Update scanner configuration - pub fn update_config(&self, config: ContentScanningConfig) { - let new_filter = Self::compile_filter(&self.scheme, &config.scan_expression); - - if let Ok(mut guard) = self.config.write() { - *guard = config; - } - if let Ok(mut guard) = self.filter.write() { - *guard = new_filter; - } - } - - /// Check if content scanning should be performed for this request - pub fn should_scan(&self, req_parts: &Parts, body_bytes: &[u8], _peer_addr: SocketAddr) -> bool { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return false, - }; - - if !config.enabled { - log::debug!("Content scanning disabled"); - return false; - } - - log::debug!("Checking if should scan request: method={}, path={}, body_size={}", - req_parts.method, req_parts.uri.path(), body_bytes.len()); - - // Check wirefilter expression first - let filter_guard = match self.filter.read() { - Ok(guard) => guard, - Err(_) => return false, - }; - - if let Some(ref filter) = *filter_guard { - let mut ctx = ExecutionContext::new(&self.scheme); - - // Set request fields - let method = req_parts.method.as_str(); - let path = req_parts.uri.path(); - let content_type = req_parts.headers - .get("content-type") - .and_then(|h| h.to_str().ok()) - .unwrap_or(""); - let content_length = body_bytes.len() as i64; - - if ctx.set_field_value(self.scheme.get_field("http.request.method").unwrap(), method).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.path").unwrap(), path).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.content_type").unwrap(), content_type).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.content_length").unwrap(), content_length).is_err() { - return false; - } - - // Execute filter - match filter.execute(&ctx) { - Ok(result) => { - if !result { - log::debug!("Skipping content scan: expression does not match"); - return false; - } else { - log::debug!("Expression matched, proceeding with content scan checks"); - log::debug!("Expression result: {:?}", result); - } - } - Err(e) => { - log::error!("Failed to execute content scanning expression: {}", e); - return false; - } - } - } else { - log::debug!("No scan expression configured, allowing scan"); - } - - // Check if body is too large - if body_bytes.len() > config.max_file_size { - log::debug!("Skipping content scan: body too large ({} bytes)", body_bytes.len()); - return false; - } - - // Check content type - if let Some(content_type) = req_parts.headers.get("content-type") { - if let Ok(content_type_str) = content_type.to_str() { - let content_type_lower = content_type_str.to_lowercase(); - - // If specific content types are configured, only scan those - if !config.scan_content_types.is_empty() { - let should_scan = config.scan_content_types.iter() - .any(|ct| content_type_lower.contains(ct)); - if !should_scan { - log::debug!("Skipping content scan: content type '{}' not in scan list: 
{:?}", - content_type_str, config.scan_content_types); - return false; - } else { - log::debug!("Content type '{}' matches scan list", content_type_str); - } - } - - // Skip certain content types - if content_type_lower.contains("image/") || - content_type_lower.contains("video/") || - content_type_lower.contains("audio/") { - log::debug!("Skipping content scan: binary content type {}", content_type_str); - return false; - } - } - } - - // Check file extension from URL path - if let Some(path) = req_parts.uri.path().split('/').last() { - if let Some(extension) = std::path::Path::new(path).extension() { - if let Some(ext_str) = extension.to_str() { - let ext_lower = format!(".{}", ext_str.to_lowercase()); - if config.skip_extensions.contains(&ext_lower) { - log::debug!("Skipping content scan: file extension {} in skip list", ext_lower); - return false; - } - } - } - } - - log::debug!("All checks passed, will scan content"); - true - } - - /// Scan content for malware - pub async fn scan_content(&self, body_bytes: &[u8]) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - self.scan_bytes(&config.clamav_server, body_bytes).await - } - - /// Internal method to scan bytes with ClamAV - async fn scan_bytes(&self, clamav_server: &str, data: &[u8]) -> Result { - // Create a cursor over the body bytes for scanning - let mut cursor = Cursor::new(data); - - // Perform the scan - match scan(clamav_server, &mut cursor, None) { - Ok(result) => { - // Check if malware was detected using the new API - if !result.is_infected { - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } else { - // Extract signature name from detected_infections - let signature = if !result.detected_infections.is_empty() { - Some(result.detected_infections.join(", ")) - } else { - None - }; - - Ok(ScanResult { - malware_detected: true, - signature, - error: None, - }) - } - } - Err(e) => { - log::error!("ClamAV scan failed: {}", e); - Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("Scan failed: {}", e)), - }) - } - } - } - - /// Scan multipart content for malware by parsing parts and scanning each individually - pub async fn scan_multipart_content(&self, body_bytes: &[u8], boundary: &str) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - log::debug!("Parsing multipart body with boundary: {}", boundary); - - // Create a multipart parser - let stream = futures::stream::once(async move { - Result::::Ok(Bytes::copy_from_slice(body_bytes)) - }); - - let mut multipart = Multipart::new(stream, boundary); - - let mut parts_scanned = 0; - let mut parts_failed = 0; - - // Iterate over each part in the multipart body - while let Some(field) = multipart.next_field().await.map_err(|e| anyhow!("Failed to read multipart field: {}", e))? 
{ - let field_name = field.name().unwrap_or("").to_string(); - let field_filename = field.file_name().map(|s| s.to_string()); - let field_content_type = field.content_type().map(|m| m.to_string()); - - log::debug!("Scanning multipart field: name={}, filename={:?}, content_type={:?}", - field_name, field_filename, field_content_type); - - // Read the entire field into bytes - let field_bytes = field.bytes().await.map_err(|e| anyhow!("Failed to read field bytes: {}", e))?; - - // Skip empty fields - if field_bytes.is_empty() { - log::debug!("Skipping empty multipart field: {}", field_name); - continue; - } - - // Check if field size exceeds max_file_size - if field_bytes.len() > config.max_file_size { - log::debug!("Skipping multipart field '{}': size {} exceeds max_file_size {}", - field_name, field_bytes.len(), config.max_file_size); - continue; - } - - parts_scanned += 1; - - // Scan this part - match self.scan_bytes(&config.clamav_server, &field_bytes).await { - Ok(result) => { - if result.malware_detected { - log::info!("Malware detected in multipart field '{}' (filename: {:?}): signature {:?}", - field_name, field_filename, result.signature); - - // Return immediately on first malware detection - return Ok(ScanResult { - malware_detected: true, - signature: result.signature.map(|s| format!("{}:{}", field_name, s)), - error: None, - }); - } - } - Err(e) => { - log::warn!("Failed to scan multipart field '{}': {}", field_name, e); - parts_failed += 1; - } - } - } - - log::debug!("Multipart scan complete: {} parts scanned, {} failed", parts_scanned, parts_failed); - - // If all parts failed to scan, return an error - if parts_scanned > 0 && parts_failed == parts_scanned { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("All {} multipart parts failed to scan", parts_failed)), - }); - } - - // No malware detected - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } - - /// Scan HTML form data for malware - pub async fn scan_form_data(&self, form_data: &HashMap) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - // Combine all form values into a single string for scanning - let combined_data = form_data.values() - .map(|v| v.as_str()) - .collect::>() - .join("\n"); - - let mut cursor = Cursor::new(combined_data.as_bytes()); - - // Perform the scan - match scan(&config.clamav_server, &mut cursor, None) { - Ok(result) => { - // Check if malware was detected using the new API - if !result.is_infected { - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } else { - // Extract signature name from detected_infections - let signature = if !result.detected_infections.is_empty() { - Some(result.detected_infections.join(", ")) - } else { - None - }; - - Ok(ScanResult { - malware_detected: true, - signature, - error: None, - }) - } - } - Err(e) => { - log::error!("ClamAV form data scan failed: {}", e); - Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("Form data scan failed: {}", e)), - }) - } - } - } -} - -// Global content scanner instance -static CONTENT_SCANNER: OnceLock = OnceLock::new(); - -/// Get the global content scanner instance -pub fn get_global_content_scanner() -> Option<&'static ContentScanner> { - 
CONTENT_SCANNER.get() -} - -/// Set the global content scanner instance -pub fn set_global_content_scanner(scanner: ContentScanner) -> Result<()> { - CONTENT_SCANNER - .set(scanner) - .map_err(|_| anyhow!("Failed to initialize content scanner")) -} - -/// Initialize the global content scanner with default configuration -pub fn init_content_scanner(config: ContentScanningConfig) -> Result<()> { - let scanner = ContentScanner::new(config); - set_global_content_scanner(scanner)?; - log::info!("Content scanner initialized"); - Ok(()) -} - -/// Update the global content scanner configuration -pub fn update_content_scanner_config(config: ContentScanningConfig) -> Result<()> { - if let Some(scanner) = get_global_content_scanner() { - scanner.update_config(config); - log::info!("Content scanner configuration updated"); - Ok(()) - } else { - Err(anyhow!("Content scanner not initialized")) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hyper::http::request::Builder; - - #[test] - fn test_content_scanner_config_default() { - let config = ContentScanningConfig::default(); - assert!(!config.enabled); - assert_eq!(config.clamav_server, "localhost:3310"); - assert_eq!(config.max_file_size, 10 * 1024 * 1024); - assert!(!config.scan_content_types.is_empty()); - assert!(config.skip_extensions.is_empty()); - } - - #[test] - fn test_should_scan_disabled() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: false, - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - assert!(!scanner.should_scan(&req_parts, b"test content", peer_addr)); - } - - #[test] - fn test_should_scan_content_type_filter() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: true, - scan_content_types: vec!["text/html".to_string()], - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .header("content-type", "text/html") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - assert!(scanner.should_scan(&req_parts, b"test", peer_addr)); - - let req2 = Builder::new() - .method("POST") - .uri("http://example.com/test") - .header("content-type", "application/json") - .body(()) - .unwrap(); - let (req_parts2, _) = req2.into_parts(); - - assert!(!scanner.should_scan(&req_parts2, b"{\"test\": \"data\"}", peer_addr)); - } - - #[test] - fn test_should_scan_file_size_limit() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: true, - max_file_size: 100, - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - let small_content = b"small"; - let large_content = b"x".repeat(200); - - assert!(scanner.should_scan(&req_parts, small_content, peer_addr)); - assert!(!scanner.should_scan(&req_parts, &large_content, peer_addr)); - } - - #[test] - fn test_extract_multipart_boundary() { - // Test with standard format - let ct1 = "multipart/form-data; 
boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
-        assert_eq!(
-            extract_multipart_boundary(ct1),
-            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
-        );
-
-        // Test with quoted boundary
-        let ct2 = "multipart/form-data; boundary=\"----WebKitFormBoundary7MA4YWxkTrZu0gW\"";
-        assert_eq!(
-            extract_multipart_boundary(ct2),
-            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
-        );
-
-        // Test with spaces
-        let ct3 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW ";
-        assert_eq!(
-            extract_multipart_boundary(ct3),
-            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
-        );
-
-        // Test with charset and boundary
-        let ct4 = "multipart/form-data; charset=utf-8; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
-        assert_eq!(
-            extract_multipart_boundary(ct4),
-            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
-        );
-
-        // Test non-multipart content type
-        let ct5 = "application/json";
-        assert_eq!(extract_multipart_boundary(ct5), None);
-
-        // Test missing boundary
-        let ct6 = "multipart/form-data";
-        assert_eq!(extract_multipart_boundary(ct6), None);
-
-        // Test mixed case
-        let ct7 = "Multipart/Form-Data; Boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
-        assert_eq!(
-            extract_multipart_boundary(ct7),
-            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
-        );
-    }
-}
+use anyhow::{Result, anyhow};
+use clamav_tcp::scan;
+use serde::{Deserialize, Serialize};
+use std::io::Cursor;
+use std::sync::{Arc, RwLock, OnceLock};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use hyper::http::request::Parts;
+use wirefilter::{ExecutionContext, Scheme, Filter};
+use bytes::Bytes;
+use multer::Multipart;
+
+/// Content scanning configuration
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ContentScanningConfig {
+    /// Enable or disable content scanning
+    pub enabled: bool,
+    /// ClamAV server address (e.g., "localhost:3310")
+    pub clamav_server: String,
+    /// Maximum file size to scan in bytes (default: 10MB)
+    pub max_file_size: usize,
+    /// Content types to scan (empty means scan all)
+    pub scan_content_types: Vec<String>,
+    /// Skip scanning for specific file extensions
+    pub skip_extensions: Vec<String>,
+    /// Wirefilter expression to determine when to scan
+    #[serde(default = "default_scan_expression")]
+    pub scan_expression: String,
+}
+
+fn default_scan_expression() -> String {
+    "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string()
+}
+
+impl Default for ContentScanningConfig {
+    fn default() -> Self {
+        Self {
+            enabled: false,
+            clamav_server: "localhost:3310".to_string(),
+            max_file_size: 10 * 1024 * 1024, // 10MB
+            scan_content_types: vec![
+                "text/html".to_string(),
+                "application/x-www-form-urlencoded".to_string(),
+                "multipart/form-data".to_string(),
+                "application/json".to_string(),
+                "text/plain".to_string(),
+            ],
+            skip_extensions: vec![],
+            scan_expression: default_scan_expression(),
+        }
+    }
+}
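+// Illustrative YAML (hypothetical config file section; the section name is an
+// assumption, the field names follow the struct above):
+//     content_scanning:
+//       enabled: true
+//       clamav_server: "localhost:3310"
+//       max_file_size: 10485760
+//       scan_content_types: ["multipart/form-data", "application/json"]
+//       skip_extensions: [".jpg", ".png"]
+// scan_expression falls back to default_scan_expression() when omitted.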
+/// Content scanning result
+#[derive(Debug, Clone)]
+pub struct ScanResult {
+    /// Whether malware was detected
+    pub malware_detected: bool,
+    /// Malware signature name if detected
+    pub signature: Option<String>,
+    /// Error message if scanning failed
+    pub error: Option<String>,
+}
+
+/// Content scanner implementation
+pub struct ContentScanner {
+    config: Arc<RwLock<ContentScanningConfig>>,
+    scheme: Arc<Scheme>,
+    filter: Arc<RwLock<Option<Filter>>>,
+}
+
+/// Extract boundary from Content-Type header for multipart content
+pub fn extract_multipart_boundary(content_type: &str) -> Option<String> {
+    // Content-Type format: multipart/form-data; boundary=----WebKitFormBoundary...
+    if !content_type.to_lowercase().contains("multipart/") {
+        return None;
+    }
+
+    for part in content_type.split(';') {
+        let trimmed = part.trim();
+        let lower = trimmed.to_lowercase();
+        if lower.starts_with("boundary=") {
+            // Find the actual position of "boundary=" in the original string (case-insensitive)
+            if let Some(eq_pos) = trimmed.to_lowercase().find("boundary=") {
+                let boundary = trimmed[eq_pos + 9..].trim();
+                // Remove quotes if present
+                let boundary = boundary.trim_matches('"').trim_matches('\'');
+                return Some(boundary.to_string());
+            }
+        }
+    }
+
+    None
+}
+
+impl ContentScanner {
+    /// Create a new content scanner
+    pub fn new(config: ContentScanningConfig) -> Self {
+        let scheme = Self::create_scheme();
+        let filter = Self::compile_filter(&scheme, &config.scan_expression);
+
+        Self {
+            config: Arc::new(RwLock::new(config)),
+            scheme: Arc::new(scheme),
+            filter: Arc::new(RwLock::new(filter)),
+        }
+    }
+
+    /// Create the wirefilter scheme for content scanning
+    fn create_scheme() -> Scheme {
+        let builder = wirefilter::Scheme! {
+            http.request.method: Bytes,
+            http.request.path: Bytes,
+            http.request.content_type: Bytes,
+            http.request.content_length: Int,
+        };
+        builder.build()
+    }
+
+    /// Compile the scan expression filter
+    fn compile_filter(scheme: &Scheme, expression: &str) -> Option<Filter> {
+        if expression.is_empty() {
+            return None;
+        }
+
+        match scheme.parse(expression) {
+            Ok(ast) => Some(ast.compile()),
+            Err(e) => {
+                log::error!("Failed to compile content scanning expression '{}': {}", expression, e);
+                None
+            }
+        }
+    }
+
+    /// Update scanner configuration
+    pub fn update_config(&self, config: ContentScanningConfig) {
+        let new_filter = Self::compile_filter(&self.scheme, &config.scan_expression);
+
+        if let Ok(mut guard) = self.config.write() {
+            *guard = config;
+        }
+        if let Ok(mut guard) = self.filter.write() {
+            *guard = new_filter;
+        }
+    }
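+    // Illustrative scan expression (hypothetical; any valid wirefilter
+    // expression over the four fields registered in create_scheme() works):
+    //     http.request.method eq "POST" and http.request.content_length gt 1024
+    // would restrict scanning to POST bodies larger than 1024 bytes. An
+    // expression that fails to parse is logged by compile_filter() and
+    // disables the expression gate rather than blocking requests.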
not match"); + return false; + } else { + log::debug!("Expression matched, proceeding with content scan checks"); + log::debug!("Expression result: {:?}", result); + } + } + Err(e) => { + log::error!("Failed to execute content scanning expression: {}", e); + return false; + } + } + } else { + log::debug!("No scan expression configured, allowing scan"); + } + + // Check if body is too large + if body_bytes.len() > config.max_file_size { + log::debug!("Skipping content scan: body too large ({} bytes)", body_bytes.len()); + return false; + } + + // Check content type + if let Some(content_type) = req_parts.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + let content_type_lower = content_type_str.to_lowercase(); + + // If specific content types are configured, only scan those + if !config.scan_content_types.is_empty() { + let should_scan = config.scan_content_types.iter() + .any(|ct| content_type_lower.contains(ct)); + if !should_scan { + log::debug!("Skipping content scan: content type '{}' not in scan list: {:?}", + content_type_str, config.scan_content_types); + return false; + } else { + log::debug!("Content type '{}' matches scan list", content_type_str); + } + } + + // Skip certain content types + if content_type_lower.contains("image/") || + content_type_lower.contains("video/") || + content_type_lower.contains("audio/") { + log::debug!("Skipping content scan: binary content type {}", content_type_str); + return false; + } + } + } + + // Check file extension from URL path + if let Some(path) = req_parts.uri.path().split('/').last() { + if let Some(extension) = std::path::Path::new(path).extension() { + if let Some(ext_str) = extension.to_str() { + let ext_lower = format!(".{}", ext_str.to_lowercase()); + if config.skip_extensions.contains(&ext_lower) { + log::debug!("Skipping content scan: file extension {} in skip list", ext_lower); + return false; + } + } + } + } + + log::debug!("All checks passed, will scan content"); + true + } + + /// Scan content for malware + pub async fn scan_content(&self, body_bytes: &[u8]) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + self.scan_bytes(&config.clamav_server, body_bytes).await + } + + /// Internal method to scan bytes with ClamAV + async fn scan_bytes(&self, clamav_server: &str, data: &[u8]) -> Result { + // Create a cursor over the body bytes for scanning + let mut cursor = Cursor::new(data); + + // Perform the scan + match scan(clamav_server, &mut cursor, None) { + Ok(result) => { + // Check if malware was detected using the new API + if !result.is_infected { + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } else { + // Extract signature name from detected_infections + let signature = if !result.detected_infections.is_empty() { + Some(result.detected_infections.join(", ")) + } else { + None + }; + + Ok(ScanResult { + malware_detected: true, + signature, + error: None, + }) + } + } + Err(e) => { + log::error!("ClamAV scan failed: {}", e); + Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("Scan failed: {}", e)), + }) + } + } + } + + /// Scan multipart content for malware by parsing parts and scanning each individually + pub async fn scan_multipart_content(&self, body_bytes: &[u8], boundary: &str) -> 
+    /// Scan multipart content for malware by parsing parts and scanning each individually
+    pub async fn scan_multipart_content(&self, body_bytes: &[u8], boundary: &str) -> Result<ScanResult> {
+        let config = match self.config.read() {
+            Ok(guard) => guard.clone(),
+            Err(_) => return Err(anyhow!("Failed to read scanner config")),
+        };
+
+        if !config.enabled {
+            return Ok(ScanResult {
+                malware_detected: false,
+                signature: None,
+                error: None,
+            });
+        }
+
+        log::debug!("Parsing multipart body with boundary: {}", boundary);
+
+        // Create a multipart parser
+        let stream = futures::stream::once(async move {
+            Result::<Bytes>::Ok(Bytes::copy_from_slice(body_bytes))
+        });
+
+        let mut multipart = Multipart::new(stream, boundary);
+
+        let mut parts_scanned = 0;
+        let mut parts_failed = 0;
+
+        // Iterate over each part in the multipart body
+        while let Some(field) = multipart.next_field().await.map_err(|e| anyhow!("Failed to read multipart field: {}", e))? {
+            let field_name = field.name().unwrap_or("").to_string();
+            let field_filename = field.file_name().map(|s| s.to_string());
+            let field_content_type = field.content_type().map(|m| m.to_string());
+
+            log::debug!("Scanning multipart field: name={}, filename={:?}, content_type={:?}",
+                field_name, field_filename, field_content_type);
+
+            // Read the entire field into bytes
+            let field_bytes = field.bytes().await.map_err(|e| anyhow!("Failed to read field bytes: {}", e))?;
+
+            // Skip empty fields
+            if field_bytes.is_empty() {
+                log::debug!("Skipping empty multipart field: {}", field_name);
+                continue;
+            }
+
+            // Check if field size exceeds max_file_size
+            if field_bytes.len() > config.max_file_size {
+                log::debug!("Skipping multipart field '{}': size {} exceeds max_file_size {}",
+                    field_name, field_bytes.len(), config.max_file_size);
+                continue;
+            }
+
+            parts_scanned += 1;
+
+            // Scan this part
+            match self.scan_bytes(&config.clamav_server, &field_bytes).await {
+                Ok(result) => {
+                    if result.malware_detected {
+                        log::info!("Malware detected in multipart field '{}' (filename: {:?}): signature {:?}",
+                            field_name, field_filename, result.signature);
+
+                        // Return immediately on first malware detection
+                        return Ok(ScanResult {
+                            malware_detected: true,
+                            signature: result.signature.map(|s| format!("{}:{}", field_name, s)),
+                            error: None,
+                        });
+                    }
+                }
+                Err(e) => {
+                    log::warn!("Failed to scan multipart field '{}': {}", field_name, e);
+                    parts_failed += 1;
+                }
+            }
+        }
+
+        log::debug!("Multipart scan complete: {} parts scanned, {} failed", parts_scanned, parts_failed);
+
+        // If all parts failed to scan, return an error
+        if parts_scanned > 0 && parts_failed == parts_scanned {
+            return Ok(ScanResult {
+                malware_detected: false,
+                signature: None,
+                error: Some(format!("All {} multipart parts failed to scan", parts_failed)),
+            });
+        }
+
+        // No malware detected
+        Ok(ScanResult {
+            malware_detected: false,
+            signature: None,
+            error: None,
+        })
+    }
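+    // Illustrative flow (hypothetical request): for a body with two parts,
+    // "avatar" (clean) and "payload" (infected), the scan stops at the first
+    // detection and returns signature Some("payload:<clamav-signature>");
+    // per-part failures only surface as an error when every scanned part fails.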
+    /// Scan HTML form data for malware
+    pub async fn scan_form_data(&self, form_data: &HashMap<String, String>) -> Result<ScanResult> {
+        let config = match self.config.read() {
+            Ok(guard) => guard.clone(),
+            Err(_) => return Err(anyhow!("Failed to read scanner config")),
+        };
+
+        if !config.enabled {
+            return Ok(ScanResult {
+                malware_detected: false,
+                signature: None,
+                error: None,
+            });
+        }
+
+        // Combine all form values into a single string for scanning
+        let combined_data = form_data.values()
+            .map(|v| v.as_str())
+            .collect::<Vec<_>>()
+            .join("\n");
+
+        let mut cursor = Cursor::new(combined_data.as_bytes());
+
+        // Perform the scan
+        match scan(&config.clamav_server, &mut cursor, None) {
+            Ok(result) => {
+                // Check if malware was detected using the new API
+                if !result.is_infected {
+                    Ok(ScanResult {
+                        malware_detected: false,
+                        signature: None,
+                        error: None,
+                    })
+                } else {
+                    // Extract signature name from detected_infections
+                    let signature = if !result.detected_infections.is_empty() {
+                        Some(result.detected_infections.join(", "))
+                    } else {
+                        None
+                    };
+
+                    Ok(ScanResult {
+                        malware_detected: true,
+                        signature,
+                        error: None,
+                    })
+                }
+            }
+            Err(e) => {
+                log::error!("ClamAV form data scan failed: {}", e);
+                Ok(ScanResult {
+                    malware_detected: false,
+                    signature: None,
+                    error: Some(format!("Form data scan failed: {}", e)),
+                })
+            }
+        }
+    }
+}
+
+// Global content scanner instance
+static CONTENT_SCANNER: OnceLock<ContentScanner> = OnceLock::new();
+
+/// Get the global content scanner instance
+pub fn get_global_content_scanner() -> Option<&'static ContentScanner> {
+    CONTENT_SCANNER.get()
+}
+
+/// Set the global content scanner instance
+pub fn set_global_content_scanner(scanner: ContentScanner) -> Result<()> {
+    CONTENT_SCANNER
+        .set(scanner)
+        .map_err(|_| anyhow!("Failed to initialize content scanner"))
+}
+
+/// Initialize the global content scanner with the given configuration
+pub fn init_content_scanner(config: ContentScanningConfig) -> Result<()> {
+    let scanner = ContentScanner::new(config);
+    set_global_content_scanner(scanner)?;
+    log::info!("Content scanner initialized");
+    Ok(())
+}
+
+/// Update the global content scanner configuration
+pub fn update_content_scanner_config(config: ContentScanningConfig) -> Result<()> {
+    if let Some(scanner) = get_global_content_scanner() {
+        scanner.update_config(config);
+        log::info!("Content scanner configuration updated");
+        Ok(())
+    } else {
+        Err(anyhow!("Content scanner not initialized"))
+    }
+}
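+// Illustrative startup wiring (hypothetical call site, e.g. in main):
+//     init_content_scanner(ContentScanningConfig::default())?;
+//     if let Some(scanner) = get_global_content_scanner() {
+//         // scanner.should_scan(...) / scanner.scan_content(...).await
+//     }
+// OnceLock makes init_content_scanner() fail on a second call, so runtime
+// changes must go through update_content_scanner_config() instead.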
+    fn test_should_scan_file_size_limit() {
+        use std::net::{Ipv4Addr, SocketAddr};
+
+        let config = ContentScanningConfig {
+            enabled: true,
+            max_file_size: 100,
+            ..Default::default()
+        };
+        let scanner = ContentScanner::new(config);
+
+        let req = Builder::new()
+            .method("POST")
+            .uri("http://example.com/test")
+            .body(())
+            .unwrap();
+        let (req_parts, _) = req.into_parts();
+        let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080);
+
+        let small_content = b"small";
+        let large_content = b"x".repeat(200);
+
+        assert!(scanner.should_scan(&req_parts, small_content, peer_addr));
+        assert!(!scanner.should_scan(&req_parts, &large_content, peer_addr));
+    }
+
+    #[test]
+    fn test_extract_multipart_boundary() {
+        // Test with standard format
+        let ct1 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
+        assert_eq!(
+            extract_multipart_boundary(ct1),
+            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
+        );
+
+        // Test with quoted boundary
+        let ct2 = "multipart/form-data; boundary=\"----WebKitFormBoundary7MA4YWxkTrZu0gW\"";
+        assert_eq!(
+            extract_multipart_boundary(ct2),
+            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
+        );
+
+        // Test with spaces
+        let ct3 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW ";
+        assert_eq!(
+            extract_multipart_boundary(ct3),
+            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
+        );
+
+        // Test with charset and boundary
+        let ct4 = "multipart/form-data; charset=utf-8; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
+        assert_eq!(
+            extract_multipart_boundary(ct4),
+            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
+        );
+
+        // Test non-multipart content type
+        let ct5 = "application/json";
+        assert_eq!(extract_multipart_boundary(ct5), None);
+
+        // Test missing boundary
+        let ct6 = "multipart/form-data";
+        assert_eq!(extract_multipart_boundary(ct6), None);
+
+        // Test mixed case
+        let ct7 = "Multipart/Form-Data; Boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW";
+        assert_eq!(
+            extract_multipart_boundary(ct7),
+            Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string())
+        );
+    }
+}
diff --git a/src/firewall/iptables.rs b/src/firewall/iptables.rs
index d08c92f..730ea5d 100644
--- a/src/firewall/iptables.rs
+++ b/src/firewall/iptables.rs
@@ -1,306 +1,306 @@
-use std::error::Error;
-use std::net::{Ipv4Addr, Ipv6Addr};
-use std::sync::atomic::{AtomicBool, Ordering};
-
-use super::Firewall;
-
-/// Chain names for synapse iptables rules
-const CHAIN_NAME: &str = "SYNAPSE_BLOCK";
-
-static IPTABLES_INITIALIZED: AtomicBool = AtomicBool::new(false);
-
-/// Iptables-based firewall implementation for when BPF/XDP and nftables are not available
-pub struct IptablesFirewall {
-    ipt4: iptables::IPTables,
-    ipt6: Option<iptables::IPTables>,
-    initialized: bool,
-}
-
-impl IptablesFirewall {
-    pub fn new() -> Result<Self, Box<dyn Error>> {
-        let ipt4 = iptables::new(false)?;
-        let ipt6 = iptables::new(true).ok(); // IPv6 might not be available
-
-        let mut fw = Self {
-            ipt4,
-            ipt6,
-            initialized: false,
-        };
-        fw.initialize()?;
-        Ok(fw)
-    }
-
-    /// Check if iptables is available on the system
-    pub fn is_available() -> bool {
-        iptables::new(false).is_ok()
-    }
-
-    /// Initialize the iptables chains
-    fn initialize(&mut self) -> Result<(), Box<dyn Error>> {
-        if IPTABLES_INITIALIZED.load(Ordering::SeqCst) {
-            self.initialized = true;
-            return Ok(());
-        }
-
-        if !Self::is_available() {
-            return Err("iptables not available on system".into());
-        }
-
-        log::info!("Initializing iptables firewall (XDP/BPF and nftables fallback)");
-
-        // Initialize IPv4 chain
-        self.init_chain(&self.ipt4, "filter")?;
-
-        // Initialize IPv6 chain if available
-        if let Some(ref ipt6) = self.ipt6 {
-            if let Err(e) = self.init_chain(ipt6, "filter") {
-                log::warn!("Failed to initialize ip6tables chain: {}", e);
-            }
-        }
-
-        IPTABLES_INITIALIZED.store(true, Ordering::SeqCst);
-        self.initialized = true;
-        log::info!("Iptables firewall initialized successfully");
-        Ok(())
-    }
-
-    fn init_chain(&self, ipt: &iptables::IPTables, table: &str) -> Result<(), Box<dyn Error>> {
-        // Create chain if it doesn't exist
-        if !ipt.chain_exists(table, CHAIN_NAME)? {
-            ipt.new_chain(table, CHAIN_NAME)?;
-        }
-
-        // Add jump rule to INPUT chain if not already present
-        let jump_rule = format!("-j {}", CHAIN_NAME);
-        if !ipt.exists(table, "INPUT", &jump_rule)? {
-            // Insert at the beginning of INPUT chain
-            ipt.insert(table, "INPUT", &jump_rule, 1)?;
-        }
-
-        Ok(())
-    }
-
-    /// Add an IPv4 address/CIDR to the block chain
-    fn add_ipv4(&self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 32 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let rule = format!("-s {} -j DROP", addr);
-
-        // Check if rule already exists
-        if !self.ipt4.exists("filter", CHAIN_NAME, &rule)? {
-            self.ipt4.append("filter", CHAIN_NAME, &rule)?;
-        }
-
-        Ok(())
-    }
-
-    /// Remove an IPv4 address/CIDR from the block chain
-    fn remove_ipv4(&self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 32 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let rule = format!("-s {} -j DROP", addr);
-
-        // Delete rule if it exists
-        if self.ipt4.exists("filter", CHAIN_NAME, &rule)? {
-            self.ipt4.delete("filter", CHAIN_NAME, &rule)?;
-        }
-
-        Ok(())
-    }
-
-    /// Check if an IPv4 address exists in the block chain
-    fn exists_ipv4(&self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
-        let rule = format!("-s {} -j DROP", ip);
-        Ok(self.ipt4.exists("filter", CHAIN_NAME, &rule)?)
-    }
-
-    /// Add an IPv6 address/CIDR to the block chain
-    fn add_ipv6(&self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
-
-        let addr = if prefixlen == 128 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let rule = format!("-s {} -j DROP", addr);
-
-        // Check if rule already exists
-        if !ipt6.exists("filter", CHAIN_NAME, &rule)? {
-            ipt6.append("filter", CHAIN_NAME, &rule)?;
-        }
-
-        Ok(())
-    }
-
-    /// Remove an IPv6 address/CIDR from the block chain
-    fn remove_ipv6(&self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
-
-        let addr = if prefixlen == 128 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let rule = format!("-s {} -j DROP", addr);
-
-        // Delete rule if it exists
-        if ipt6.exists("filter", CHAIN_NAME, &rule)? {
-            ipt6.delete("filter", CHAIN_NAME, &rule)?;
-        }
-
-        Ok(())
-    }
-
-    /// Check if an IPv6 address exists in the block chain
-    fn exists_ipv6(&self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
-        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
-        let rule = format!("-s {} -j DROP", ip);
-        Ok(ipt6.exists("filter", CHAIN_NAME, &rule)?)
-    }
-
-    /// Clean up iptables rules on shutdown
-    pub fn cleanup(&self) -> Result<(), Box<dyn Error>> {
-        log::info!("Cleaning up iptables firewall rules");
-
-        // Remove jump rule from INPUT chain
-        let jump_rule = format!("-j {}", CHAIN_NAME);
-        if self.ipt4.exists("filter", "INPUT", &jump_rule)? {
-            let _ = self.ipt4.delete("filter", "INPUT", &jump_rule);
-        }
-
-        // Flush and delete chain
-        if self.ipt4.chain_exists("filter", CHAIN_NAME)? {
-            let _ = self.ipt4.flush_chain("filter", CHAIN_NAME);
-            let _ = self.ipt4.delete_chain("filter", CHAIN_NAME);
-        }
-
-        // Do the same for IPv6
-        if let Some(ref ipt6) = self.ipt6 {
-            if ipt6.exists("filter", "INPUT", &jump_rule).unwrap_or(false) {
-                let _ = ipt6.delete("filter", "INPUT", &jump_rule);
-            }
-            if ipt6.chain_exists("filter", CHAIN_NAME).unwrap_or(false) {
-                let _ = ipt6.flush_chain("filter", CHAIN_NAME);
-                let _ = ipt6.delete_chain("filter", CHAIN_NAME);
-            }
-        }
-
-        IPTABLES_INITIALIZED.store(false, Ordering::SeqCst);
-        Ok(())
-    }
-}
-
-impl Default for IptablesFirewall {
-    fn default() -> Self {
-        Self::new().unwrap_or(Self {
-            ipt4: iptables::new(false).unwrap(),
-            ipt6: None,
-            initialized: false,
-        })
-    }
-}
-
-impl Firewall for IptablesFirewall {
-    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        // Iptables doesn't have a separate "notice" concept, just ban
-        self.add_ipv4(ip, prefixlen)?;
-        log::debug!("iptables: banned IPv4 {}/{} with notice", ip, prefixlen);
-        Ok(())
-    }
-
-    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_ipv4(ip, prefixlen)?;
-        log::debug!("iptables: banned IPv4 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.remove_ipv4(ip, prefixlen)?;
-        log::debug!("iptables: unbanned IPv4 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
-        // For iptables, we just check if the IP is banned
-        self.exists_ipv4(ip)
-    }
-
-    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_ipv6(ip, prefixlen)?;
-        log::debug!("iptables: banned IPv6 {}/{} with notice", ip, prefixlen);
-        Ok(())
-    }
-
-    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_ipv6(ip, prefixlen)?;
-        log::debug!("iptables: banned IPv6 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.remove_ipv6(ip, prefixlen)?;
-        log::debug!("iptables: unbanned IPv6 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
-        self.exists_ipv6(ip)
-    }
-
-    // TCP fingerprint blocking is not supported via iptables (requires BPF)
-    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint blocking not supported in iptables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint unblocking not supported in iptables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint blocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint unblocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        Ok(false)
-    }
-
-    fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        Ok(false)
-    }
-}
-
-impl Drop for IptablesFirewall {
-    fn drop(&mut self) {
-        // Don't cleanup on drop - rules should persist
-        // Call cleanup() explicitly if needed
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_iptables_available() {
-        let _ = IptablesFirewall::is_available();
-    }
-}
+use std::error::Error;
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use super::Firewall;
+
+/// Chain names for synapse iptables rules
+const CHAIN_NAME: &str = "SYNAPSE_BLOCK";
+
+static IPTABLES_INITIALIZED: AtomicBool = AtomicBool::new(false);
+
+/// Iptables-based firewall implementation for when BPF/XDP and nftables are not available
+pub struct IptablesFirewall {
+    ipt4: iptables::IPTables,
+    ipt6: Option<iptables::IPTables>,
+    initialized: bool,
+}
+
+impl IptablesFirewall {
+    pub fn new() -> Result<Self, Box<dyn Error>> {
+        let ipt4 = iptables::new(false)?;
+        let ipt6 = iptables::new(true).ok(); // IPv6 might not be available
+
+        let mut fw = Self {
+            ipt4,
+            ipt6,
+            initialized: false,
+        };
+        fw.initialize()?;
+        Ok(fw)
+    }
+
+    /// Check if iptables is available on the system
+    pub fn is_available() -> bool {
+        iptables::new(false).is_ok()
+    }
+
+    /// Initialize the iptables chains
+    fn initialize(&mut self) -> Result<(), Box<dyn Error>> {
+        if IPTABLES_INITIALIZED.load(Ordering::SeqCst) {
+            self.initialized = true;
+            return Ok(());
+        }
+
+        if !Self::is_available() {
+            return Err("iptables not available on system".into());
+        }
+
+        log::info!("Initializing iptables firewall (XDP/BPF and nftables fallback)");
+
+        // Initialize IPv4 chain
+        self.init_chain(&self.ipt4, "filter")?;
+
+        // Initialize IPv6 chain if available
+        if let Some(ref ipt6) = self.ipt6 {
+            if let Err(e) = self.init_chain(ipt6, "filter") {
+                log::warn!("Failed to initialize ip6tables chain: {}", e);
+            }
+        }
+
+        IPTABLES_INITIALIZED.store(true, Ordering::SeqCst);
+        self.initialized = true;
+        log::info!("Iptables firewall initialized successfully");
+        Ok(())
+    }
+
+    fn init_chain(&self, ipt: &iptables::IPTables, table: &str) -> Result<(), Box<dyn Error>> {
+        // Create chain if it doesn't exist
+        if !ipt.chain_exists(table, CHAIN_NAME)? {
+            ipt.new_chain(table, CHAIN_NAME)?;
+        }
+
+        // Add jump rule to INPUT chain if not already present
+        let jump_rule = format!("-j {}", CHAIN_NAME);
+        if !ipt.exists(table, "INPUT", &jump_rule)? {
+            // Insert at the beginning of INPUT chain
+            ipt.insert(table, "INPUT", &jump_rule, 1)?;
+        }
+
+        Ok(())
+    }
+
+    /// Add an IPv4 address/CIDR to the block chain
+    fn add_ipv4(&self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 32 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let rule = format!("-s {} -j DROP", addr);
+
+        // Check if rule already exists
+        if !self.ipt4.exists("filter", CHAIN_NAME, &rule)? {
+            self.ipt4.append("filter", CHAIN_NAME, &rule)?;
+        }
+
+        Ok(())
+    }
+
+    /// Remove an IPv4 address/CIDR from the block chain
+    fn remove_ipv4(&self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 32 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let rule = format!("-s {} -j DROP", addr);
+
+        // Delete rule if it exists
+        if self.ipt4.exists("filter", CHAIN_NAME, &rule)? {
+            self.ipt4.delete("filter", CHAIN_NAME, &rule)?;
+        }
+
+        Ok(())
+    }
+
+    /// Check if an IPv4 address exists in the block chain
+    fn exists_ipv4(&self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
+        let rule = format!("-s {} -j DROP", ip);
+        Ok(self.ipt4.exists("filter", CHAIN_NAME, &rule)?)
+    }
+
+    /// Add an IPv6 address/CIDR to the block chain
+    fn add_ipv6(&self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
+
+        let addr = if prefixlen == 128 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let rule = format!("-s {} -j DROP", addr);
+
+        // Check if rule already exists
+        if !ipt6.exists("filter", CHAIN_NAME, &rule)? {
+            ipt6.append("filter", CHAIN_NAME, &rule)?;
+        }
+
+        Ok(())
+    }
+
+    /// Remove an IPv6 address/CIDR from the block chain
+    fn remove_ipv6(&self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
+
+        let addr = if prefixlen == 128 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let rule = format!("-s {} -j DROP", addr);
+
+        // Delete rule if it exists
+        if ipt6.exists("filter", CHAIN_NAME, &rule)? {
+            ipt6.delete("filter", CHAIN_NAME, &rule)?;
+        }
+
+        Ok(())
+    }
+
+    /// Check if an IPv6 address exists in the block chain
+    fn exists_ipv6(&self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
+        let ipt6 = self.ipt6.as_ref().ok_or("ip6tables not available")?;
+        let rule = format!("-s {} -j DROP", ip);
+        Ok(ipt6.exists("filter", CHAIN_NAME, &rule)?)
+    }
+
+    /// Clean up iptables rules on shutdown
+    pub fn cleanup(&self) -> Result<(), Box<dyn Error>> {
+        log::info!("Cleaning up iptables firewall rules");
+
+        // Remove jump rule from INPUT chain
+        let jump_rule = format!("-j {}", CHAIN_NAME);
+        if self.ipt4.exists("filter", "INPUT", &jump_rule)? {
+            let _ = self.ipt4.delete("filter", "INPUT", &jump_rule);
+        }
+
+        // Flush and delete chain
+        if self.ipt4.chain_exists("filter", CHAIN_NAME)? {
+            let _ = self.ipt4.flush_chain("filter", CHAIN_NAME);
+            let _ = self.ipt4.delete_chain("filter", CHAIN_NAME);
+        }
+
+        // Do the same for IPv6
+        if let Some(ref ipt6) = self.ipt6 {
+            if ipt6.exists("filter", "INPUT", &jump_rule).unwrap_or(false) {
+                let _ = ipt6.delete("filter", "INPUT", &jump_rule);
+            }
+            if ipt6.chain_exists("filter", CHAIN_NAME).unwrap_or(false) {
+                let _ = ipt6.flush_chain("filter", CHAIN_NAME);
+                let _ = ipt6.delete_chain("filter", CHAIN_NAME);
+            }
+        }
+
+        IPTABLES_INITIALIZED.store(false, Ordering::SeqCst);
+        Ok(())
+    }
+}
+
+impl Default for IptablesFirewall {
+    fn default() -> Self {
+        Self::new().unwrap_or(Self {
+            ipt4: iptables::new(false).unwrap(),
+            ipt6: None,
+            initialized: false,
+        })
+    }
+}
+
+impl Firewall for IptablesFirewall {
+    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        // Iptables doesn't have a separate "notice" concept, just ban
+        self.add_ipv4(ip, prefixlen)?;
+        log::debug!("iptables: banned IPv4 {}/{} with notice", ip, prefixlen);
+        Ok(())
+    }
+
+    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        self.add_ipv4(ip, prefixlen)?;
+        log::debug!("iptables: banned IPv4 {}/{}", ip, prefixlen);
+        Ok(())
+    }
+
+    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        self.remove_ipv4(ip, prefixlen)?;
+        log::debug!("iptables: unbanned IPv4 {}/{}", ip, prefixlen);
+        Ok(())
+    }
+
+    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
+        // For iptables, we just check if the IP is banned
+        self.exists_ipv4(ip)
+    }
+
+    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        self.add_ipv6(ip, prefixlen)?;
+        log::debug!("iptables: banned IPv6 {}/{} with notice", ip, prefixlen);
+        Ok(())
+    }
+
+    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        self.add_ipv6(ip, prefixlen)?;
+        log::debug!("iptables: banned IPv6 {}/{}", ip, prefixlen);
+        Ok(())
+    }
+
+    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        self.remove_ipv6(ip, prefixlen)?;
+        log::debug!("iptables: unbanned IPv6 {}/{}", ip, prefixlen);
+        Ok(())
+    }
+
+    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
+        self.exists_ipv6(ip)
+    }
+
+    // TCP fingerprint blocking is not supported via iptables (requires BPF)
+    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        log::warn!("TCP fingerprint blocking not supported in iptables fallback mode (fingerprint: {})", fingerprint);
+        Ok(())
+    }
+
+    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        log::warn!("TCP fingerprint unblocking not supported in iptables fallback mode (fingerprint: {})", fingerprint);
+        Ok(())
+    }
+
+    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        log::warn!("TCP fingerprint blocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint);
+        Ok(())
+    }
+
+    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        log::warn!("TCP fingerprint unblocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint);
+        Ok(())
+    }
+
+    fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
+        Ok(false)
+    }
+
+    fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
+        Ok(false)
+    }
+}
+
+impl Drop for IptablesFirewall {
+    fn drop(&mut self) {
+        // Don't cleanup on drop - rules should persist
+        // Call cleanup() explicitly if needed
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_iptables_available() {
+        let _ = IptablesFirewall::is_available();
+    }
+}
diff --git a/src/firewall/mod.rs b/src/firewall/mod.rs
index e81a715..85ba42a 100644
--- a/src/firewall/mod.rs
+++ b/src/firewall/mod.rs
@@ -1,298 +1,298 @@
-use std::{error::Error, net::{Ipv4Addr, Ipv6Addr}};
-
-use libbpf_rs::{MapCore, MapFlags};
-use serde::{Deserialize, Serialize};
-
-use crate::utils::bpf_utils;
-
-pub mod nftables;
-pub mod iptables;
-pub use nftables::NftablesFirewall;
-pub use iptables::IptablesFirewall;
-
-/// Enum to represent the active firewall backend
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum FirewallBackend {
-    /// XDP/BPF-based firewall (high performance)
-    Xdp,
-    /// Nftables-based firewall (fallback when BPF unavailable)
-    Nftables,
-    /// Iptables-based firewall (legacy fallback)
-    Iptables,
-    /// No firewall (userland enforcement only)
-    None,
-}
-
-impl std::fmt::Display for FirewallBackend {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            FirewallBackend::Xdp => write!(f, "XDP/BPF"),
-            FirewallBackend::Nftables => write!(f, "nftables"),
-            FirewallBackend::Iptables => write!(f, "iptables"),
-            FirewallBackend::None => write!(f, "none (userland)"),
-        }
-    }
-}
-
-/// Configuration option for forcing a specific firewall backend
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
-pub enum FirewallMode {
-    /// Automatically select the best available backend (XDP > nftables > iptables > none)
-    #[default]
-    Auto,
-    /// Force XDP/BPF backend (will fail if not available)
-    Xdp,
-    /// Force nftables backend
-    Nftables,
-    /// Force iptables backend
-    Iptables,
-    /// Disable kernel firewall, userland enforcement only
-    None,
-}
-
-impl std::fmt::Display for FirewallMode {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            FirewallMode::Auto => write!(f, "auto"),
-            FirewallMode::Xdp => write!(f, "xdp"),
-            FirewallMode::Nftables => write!(f, "nftables"),
-            FirewallMode::Iptables => write!(f, "iptables"),
-            FirewallMode::None => write!(f, "none"),
-        }
-    }
-}
-
-pub trait Firewall {
-    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>>;
-
-    // IPv6 methods
-    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
-    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>>;
-
-    // TCP fingerprint blocking methods
-    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
-    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
-    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
-    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
-    fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
-    fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
-}
-
-pub struct SYNAPSEFirewall<'a> {
-    skel: &'a crate::bpf::FilterSkel<'a>,
-}
-
-impl<'a> SYNAPSEFirewall<'a> {
-    pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self {
-        Self { skel }
-    }
-
-    /// Convert a fingerprint string to a 14-byte array for BPF map
-    /// Fingerprint format is typically: "TTL:MSS:Window:Scale" (e.g., "064:1460:65535:7")
-    /// This is truncated/padded to exactly 14 bytes
-    fn fingerprint_to_bytes(fingerprint: &str) -> Result<[u8; 14], Box<dyn Error>> {
-        let mut bytes = [0u8; 14];
-        let fp_bytes = fingerprint.as_bytes();
-
-        // Copy up to 14 bytes
-        let copy_len = std::cmp::min(fp_bytes.len(), 14);
-        bytes[..copy_len].copy_from_slice(&fp_bytes[..copy_len]);
-
-        Ok(bytes)
-    }
-}
-
-impl<'a> Firewall for SYNAPSEFirewall<'a> {
-    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .recently_banned_ips
-            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        Ok(())
-    }
-
-    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .banned_ips
-            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        Ok(())
-    }
-
-    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, 32);
-
-        if let Some(val) = self
-            .skel
-            .maps
-            .recently_banned_ips
-            .lookup(ip_bytes, MapFlags::ANY)?
-        {
-            if val[0] == 1_u8 {
-                return Ok(true);
-            } else {
-                return Ok(false);
-            }
-        }
-
-        Ok(true)
-    }
-
-    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
-
-        self.skel.maps.banned_ips.delete(ip_bytes)?;
-
-        Ok(())
-    }
-
-    // IPv6 implementations
-    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .recently_banned_ips_v6
-            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        Ok(())
-    }
-
-    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .banned_ips_v6
-            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        Ok(())
-    }
-
-    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, 128);
-
-        if let Some(val) = self
-            .skel
-            .maps
-            .recently_banned_ips_v6
-            .lookup(ip_bytes, MapFlags::ANY)?
-        {
-            if val[0] == 1_u8 {
-                return Ok(true);
-            } else {
-                return Ok(false);
-            }
-        }
-
-        Ok(true)
-    }
-
-    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
-
-        self.skel.maps.banned_ips_v6.delete(ip_bytes)?;
-
-        Ok(())
-    }
-
-    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .blocked_tcp_fingerprints
-            .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        log::info!("Blocked TCP fingerprint (IPv4): {}", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-
-        self.skel
-            .maps
-            .blocked_tcp_fingerprints
-            .delete(&fp_bytes)?;
-
-        log::info!("Unblocked TCP fingerprint (IPv4): {}", fingerprint);
-        Ok(())
-    }
-
-    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-        let flag = 1_u8;
-
-        self.skel
-            .maps
-            .blocked_tcp_fingerprints_v6
-            .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
-
-        log::info!("Blocked TCP fingerprint (IPv6): {}", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-
-        self.skel
-            .maps
-            .blocked_tcp_fingerprints_v6
-            .delete(&fp_bytes)?;
-
-        log::info!("Unblocked TCP fingerprint (IPv6): {}", fingerprint);
-        Ok(())
-    }
-
-    fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-
-        if let Some(val) = self
-            .skel
-            .maps
-            .blocked_tcp_fingerprints
-            .lookup(&fp_bytes, MapFlags::ANY)?
-        {
-            if val[0] == 1_u8 {
-                return Ok(true);
-            }
-        }
-
-        Ok(false)
-    }
-
-    fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
-
-        if let Some(val) = self
-            .skel
-            .maps
-            .blocked_tcp_fingerprints_v6
-            .lookup(&fp_bytes, MapFlags::ANY)?
-        {
-            if val[0] == 1_u8 {
-                return Ok(true);
-            }
-        }
-
-        Ok(false)
-    }
-}
+use std::{error::Error, net::{Ipv4Addr, Ipv6Addr}};
+
+use libbpf_rs::{MapCore, MapFlags};
+use serde::{Deserialize, Serialize};
+
+use crate::utils::bpf_utils;
+
+pub mod nftables;
+pub mod iptables;
+pub use nftables::NftablesFirewall;
+pub use iptables::IptablesFirewall;
+
+/// Enum to represent the active firewall backend
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum FirewallBackend {
+    /// XDP/BPF-based firewall (high performance)
+    Xdp,
+    /// Nftables-based firewall (fallback when BPF unavailable)
+    Nftables,
+    /// Iptables-based firewall (legacy fallback)
+    Iptables,
+    /// No firewall (userland enforcement only)
+    None,
+}
+
+impl std::fmt::Display for FirewallBackend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FirewallBackend::Xdp => write!(f, "XDP/BPF"),
+            FirewallBackend::Nftables => write!(f, "nftables"),
+            FirewallBackend::Iptables => write!(f, "iptables"),
+            FirewallBackend::None => write!(f, "none (userland)"),
+        }
+    }
+}
+
+/// Configuration option for forcing a specific firewall backend
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase")]
+pub enum FirewallMode {
+    /// Automatically select the best available backend (XDP > nftables > iptables > none)
+    #[default]
+    Auto,
+    /// Force XDP/BPF backend (will fail if not available)
+    Xdp,
+    /// Force nftables backend
+    Nftables,
+    /// Force iptables backend
+    Iptables,
+    /// Disable kernel firewall, userland enforcement only
+    None,
+}
+
+impl std::fmt::Display for FirewallMode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            FirewallMode::Auto => write!(f, "auto"),
+            FirewallMode::Xdp => write!(f, "xdp"),
+            FirewallMode::Nftables => write!(f, "nftables"),
+            FirewallMode::Iptables => write!(f, "iptables"),
+            FirewallMode::None => write!(f, "none"),
+        }
+    }
+}
+
+pub trait Firewall {
+    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>>;
+
+    // IPv6 methods
+    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>>;
+    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>>;
+
+    // TCP fingerprint blocking methods
+    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>>;
+    fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
+    fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>>;
+}
+
+pub struct SYNAPSEFirewall<'a> {
+    skel: &'a crate::bpf::FilterSkel<'a>,
+}
+
+impl<'a> SYNAPSEFirewall<'a> {
+    pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self {
+        Self { skel }
+    }
+
+    /// Convert a fingerprint string to a 14-byte array for BPF map
+    /// Fingerprint format is typically: "TTL:MSS:Window:Scale" (e.g., "064:1460:65535:7")
+    /// This is truncated/padded to exactly 14 bytes
+    fn fingerprint_to_bytes(fingerprint: &str) -> Result<[u8; 14], Box<dyn Error>> {
+        let mut bytes = [0u8; 14];
+        let fp_bytes = fingerprint.as_bytes();
+
+        // Copy up to 14 bytes
+        let copy_len = std::cmp::min(fp_bytes.len(), 14);
+        bytes[..copy_len].copy_from_slice(&fp_bytes[..copy_len]);
+
+        Ok(bytes)
+    }
+}
+
+impl<'a> Firewall for SYNAPSEFirewall<'a> {
+    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .recently_banned_ips
+            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        Ok(())
+    }
+
+    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .banned_ips
+            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        Ok(())
+    }
+
+    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, 32);
+
+        if let Some(val) = self
+            .skel
+            .maps
+            .recently_banned_ips
+            .lookup(ip_bytes, MapFlags::ANY)?
+        {
+            if val[0] == 1_u8 {
+                return Ok(true);
+            } else {
+                return Ok(false);
+            }
+        }
+
+        Ok(true)
+    }
+
+    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ip_into_bpf_map_key_bytes(ip, prefixlen);
+
+        self.skel.maps.banned_ips.delete(ip_bytes)?;
+
+        Ok(())
+    }
+
+    // IPv6 implementations
+    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .recently_banned_ips_v6
+            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        Ok(())
+    }
+
+    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .banned_ips_v6
+            .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        Ok(())
+    }
+
+    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, 128);
+
+        if let Some(val) = self
+            .skel
+            .maps
+            .recently_banned_ips_v6
+            .lookup(ip_bytes, MapFlags::ANY)?
+        {
+            if val[0] == 1_u8 {
+                return Ok(true);
+            } else {
+                return Ok(false);
+            }
+        }
+
+        Ok(true)
+    }
+
+    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen);
+
+        self.skel.maps.banned_ips_v6.delete(ip_bytes)?;
+
+        Ok(())
+    }
+
+    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .blocked_tcp_fingerprints
+            .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        log::info!("Blocked TCP fingerprint (IPv4): {}", fingerprint);
+        Ok(())
+    }
+
+    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+
+        self.skel
+            .maps
+            .blocked_tcp_fingerprints
+            .delete(&fp_bytes)?;
+
+        log::info!("Unblocked TCP fingerprint (IPv4): {}", fingerprint);
+        Ok(())
+    }
+
+    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+        let flag = 1_u8;
+
+        self.skel
+            .maps
+            .blocked_tcp_fingerprints_v6
+            .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?;
+
+        log::info!("Blocked TCP fingerprint (IPv6): {}", fingerprint);
+        Ok(())
+    }
+
+    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+
+        self.skel
+            .maps
+            .blocked_tcp_fingerprints_v6
+            .delete(&fp_bytes)?;
+
+        log::info!("Unblocked TCP fingerprint (IPv6): {}", fingerprint);
+        Ok(())
+    }
+
+    fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+
+        if let Some(val) = self
+            .skel
+            .maps
+            .blocked_tcp_fingerprints
+            .lookup(&fp_bytes, MapFlags::ANY)?
+        {
+            if val[0] == 1_u8 {
+                return Ok(true);
+            }
+        }
+
+        Ok(false)
+    }
+
+    fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result<bool, Box<dyn Error>> {
+        let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?;
+
+        if let Some(val) = self
+            .skel
+            .maps
+            .blocked_tcp_fingerprints_v6
+            .lookup(&fp_bytes, MapFlags::ANY)?
+        {
+            if val[0] == 1_u8 {
+                return Ok(true);
+            }
+        }
+
+        Ok(false)
+    }
+}
diff --git a/src/firewall/nftables.rs b/src/firewall/nftables.rs
index f2473e6..39c0406 100644
--- a/src/firewall/nftables.rs
+++ b/src/firewall/nftables.rs
@@ -1,405 +1,405 @@
-use std::error::Error;
-use std::net::{Ipv4Addr, Ipv6Addr};
-use std::process::Command;
-use std::sync::atomic::{AtomicBool, Ordering};
-
-use super::Firewall;
-
-/// Path to nft binary - try common locations
-const NFT_PATHS: &[&str] = &["/usr/bin/nft", "/usr/sbin/nft", "/sbin/nft", "nft"];
-
-/// Table and chain names for synapse nftables rules
-const NFT_TABLE_NAME: &str = "synapse";
-const NFT_CHAIN_INPUT: &str = "synapse_input";
-const NFT_SET_BANNED_IPV4: &str = "banned_ips_v4";
-const NFT_SET_BANNED_IPV6: &str = "banned_ips_v6";
-const NFT_SET_NOTICE_IPV4: &str = "notice_ips_v4";
-const NFT_SET_NOTICE_IPV6: &str = "notice_ips_v6";
-
-/// Find the nft binary path
-fn find_nft_path() -> Option<&'static str> {
-    for path in NFT_PATHS {
-        if Command::new(path)
-            .arg("--version")
-            .output()
-            .map(|o| o.status.success())
-            .unwrap_or(false)
-        {
-            return Some(path);
-        }
-    }
-    None
-}
-
-/// Get the nft command - caches the path after first lookup
-fn nft_cmd() -> Command {
-    static NFT_PATH: std::sync::OnceLock<Option<&'static str>> = std::sync::OnceLock::new();
-    let path = NFT_PATH.get_or_init(find_nft_path);
-    Command::new(path.unwrap_or("nft"))
-}
-
-static NFTABLES_INITIALIZED: AtomicBool = AtomicBool::new(false);
-
-/// Nftables-based firewall implementation for when BPF/XDP is not available
-pub struct NftablesFirewall {
-    initialized: bool,
-}
-
-impl NftablesFirewall {
-    pub fn new() -> Result<Self, Box<dyn Error>> {
-        let mut fw = Self { initialized: false };
-        fw.initialize()?;
-        Ok(fw)
-    }
-
-    /// Check if nftables is available on the system
-    pub fn is_available() -> bool {
-        find_nft_path().is_some()
-    }
-
-    /// Initialize the nftables table, chains, and sets
-    fn initialize(&mut self) -> Result<(), Box<dyn Error>> {
-        if NFTABLES_INITIALIZED.load(Ordering::SeqCst) {
-            self.initialized = true;
-            return Ok(());
-        }
-
-        if !Self::is_available() {
-            return Err("nftables (nft) command not found on system".into());
-        }
-
-        log::info!("Initializing nftables firewall (XDP/BPF fallback)");
-
-        // Use nft command directly for more reliable initialization
-        self.init_with_nft_command()?;
-
-        NFTABLES_INITIALIZED.store(true, Ordering::SeqCst);
-        self.initialized = true;
-        log::info!("Nftables firewall initialized successfully");
-        Ok(())
-    }
-
-    /// Initialize using nft command directly (more reliable)
-    fn init_with_nft_command(&self) -> Result<(), Box<dyn Error>> {
-        // Use nft -f with heredoc-style input for reliable parsing
-        let nft_script = format!(r#"
-table inet {table} {{
-    set {set_v4} {{
-        type ipv4_addr
-        flags interval
-    }}
-    set {set_v6} {{
-        type ipv6_addr
-        flags interval
-    }}
-    set {notice_v4} {{
-        type ipv4_addr
-        flags interval
-    }}
-    set {notice_v6} {{
-        type ipv6_addr
-        flags interval
-    }}
-    chain {chain} {{
-        type filter hook input priority -100; policy accept;
-        ip saddr @{set_v4} drop
-        ip6 saddr @{set_v6} drop
-    }}
-}}
-"#,
-            table = NFT_TABLE_NAME,
-            set_v4 = NFT_SET_BANNED_IPV4,
-            set_v6 = NFT_SET_BANNED_IPV6,
-            notice_v4 = NFT_SET_NOTICE_IPV4,
-            notice_v6 = NFT_SET_NOTICE_IPV6,
-            chain = NFT_CHAIN_INPUT,
-        );
-
-        // First try to delete existing table (ignore errors)
-        let _ = nft_cmd()
-            .args(["delete", "table", "inet", NFT_TABLE_NAME])
-            .output();
-
-        // Create table with all sets and chains
-        let output = nft_cmd()
-            .args(["-f", "-"])
-            .stdin(std::process::Stdio::piped())
-            .stdout(std::process::Stdio::piped())
-            .stderr(std::process::Stdio::piped())
-            .spawn()
-            .and_then(|mut child| {
-                use std::io::Write;
-                if let Some(stdin) = child.stdin.as_mut() {
-                    stdin.write_all(nft_script.as_bytes())?;
-                }
-                child.wait_with_output()
-            })?;
-
-        if !output.status.success() {
-            let stderr = String::from_utf8_lossy(&output.stderr);
-
-            // Check for common error patterns and provide user-friendly messages
-            if stderr.contains("Operation not supported") {
-                log::debug!("nftables kernel support not available: {}", stderr);
-                return Err("nftables kernel support not available (Operation not supported) - this is common in containers or systems without nf_tables kernel module".into());
-            } else if stderr.contains("Permission denied") {
-                return Err("nftables permission denied - ensure synapse runs as root with CAP_NET_ADMIN".into());
-            } else if stderr.contains("No such file or directory") {
-                return Err("nftables failed - nft command or required files not found".into());
-            } else {
-                log::debug!("nftables initialization failed: {}", stderr);
-                return Err(format!("nftables initialization failed: {}", stderr.lines().next().unwrap_or("unknown error")).into());
-            }
-        }
-
-        Ok(())
-    }
-
-    /// Add an IPv4 address/CIDR to a set
-    fn add_to_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 32 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let element = format!("{{ {} }}", addr);
-        let output = nft_cmd()
-            .args(["add", "element", "inet", NFT_TABLE_NAME, set_name, &element])
-            .output()?;
-
-        if !output.status.success() {
-            let stderr = String::from_utf8_lossy(&output.stderr);
-            let stdout = String::from_utf8_lossy(&output.stdout);
-            // Ignore "already exists" errors
-            if !stderr.contains("exists") && !stdout.contains("exists") {
-                let error_msg = if stderr.is_empty() && stdout.is_empty() {
-                    format!("exit code: {:?}", output.status.code())
-                } else if stderr.is_empty() {
-                    stdout.to_string()
-                } else {
-                    stderr.to_string()
-                };
-                return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into());
-            }
-        }
-        Ok(())
-    }
-
-    /// Remove an IPv4 address/CIDR from a set
-    fn remove_from_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 32 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let output = nft_cmd()
-            .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)])
-            .output()?;
-
-        if !output.status.success() {
-            let stderr = String::from_utf8_lossy(&output.stderr);
-            // Ignore "not found" errors
-            if !stderr.contains("No such") && !stderr.contains("does not exist") {
-                return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into());
-            }
-        }
-        Ok(())
-    }
-
-    /// Check if an IPv4 address exists in a set
-    fn exists_in_set_v4(&self, set_name: &str, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
-        let output = nft_cmd()
-            .args(["list", "set", "inet", NFT_TABLE_NAME, set_name])
-            .output()?;
-
-        if output.status.success() {
-            let stdout = String::from_utf8_lossy(&output.stdout);
-            Ok(stdout.contains(&ip.to_string()))
-        } else {
-            Ok(false)
-        }
-    }
-
-    /// Add an IPv6 address/CIDR to a set
-    fn add_to_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 128 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let element = format!("{{ {} }}", addr);
-        let output = nft_cmd()
-            .args(["add", "element", "inet", NFT_TABLE_NAME, set_name, &element])
-            .output()?;
-
-        if !output.status.success() {
-            let stderr = String::from_utf8_lossy(&output.stderr);
-            let stdout = String::from_utf8_lossy(&output.stdout);
-            if !stderr.contains("exists") && !stdout.contains("exists") {
-                let error_msg = if stderr.is_empty() && stdout.is_empty() {
-                    format!("exit code: {:?}", output.status.code())
-                } else if stderr.is_empty() {
-                    stdout.to_string()
-                } else {
-                    stderr.to_string()
-                };
-                return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into());
-            }
-        }
-        Ok(())
-    }
-
-    /// Remove an IPv6 address/CIDR from a set
-    fn remove_from_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        let addr = if prefixlen == 128 {
-            ip.to_string()
-        } else {
-            format!("{}/{}", ip, prefixlen)
-        };
-
-        let output = nft_cmd()
-            .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)])
-            .output()?;
-
-        if !output.status.success() {
-            let stderr = String::from_utf8_lossy(&output.stderr);
-            if !stderr.contains("No such") && !stderr.contains("does not exist") {
-                return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into());
-            }
-        }
-        Ok(())
-    }
-
-    /// Check if an IPv6 address exists in a set
-    fn exists_in_set_v6(&self, set_name: &str, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
-        let output = nft_cmd()
-            .args(["list", "set", "inet", NFT_TABLE_NAME, set_name])
-            .output()?;
-
-        if output.status.success() {
-            let stdout = String::from_utf8_lossy(&output.stdout);
-            Ok(stdout.contains(&ip.to_string()))
-        } else {
-            Ok(false)
-        }
-    }
-
-    /// Clean up nftables rules on shutdown
-    pub fn cleanup(&self) -> Result<(), Box<dyn Error>> {
-        log::info!("Cleaning up nftables firewall rules");
-        let _ = nft_cmd()
-            .args(["delete", "table", "inet", NFT_TABLE_NAME])
-            .output();
-        NFTABLES_INITIALIZED.store(false, Ordering::SeqCst);
-        Ok(())
-    }
-}
-
-impl Default for NftablesFirewall {
-    fn default() -> Self {
-        Self::new().unwrap_or(Self { initialized: false })
-    }
-}
-
-impl Firewall for NftablesFirewall {
-    fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_to_set_v4(NFT_SET_NOTICE_IPV4, ip, prefixlen)?;
-        self.add_to_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?;
-        log::debug!("nftables: banned IPv4 {}/{} with notice", ip, prefixlen);
-        Ok(())
-    }
-
-    fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_to_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?;
-        log::debug!("nftables: banned IPv4 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.remove_from_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?;
-        self.remove_from_set_v4(NFT_SET_NOTICE_IPV4, ip, prefixlen)?;
-        log::debug!("nftables: unbanned IPv4 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
-        self.exists_in_set_v4(NFT_SET_NOTICE_IPV4, ip)
-    }
-
-    fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_to_set_v6(NFT_SET_NOTICE_IPV6, ip, prefixlen)?;
-        self.add_to_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?;
-        log::debug!("nftables: banned IPv6 {}/{} with notice", ip, prefixlen);
-        Ok(())
-    }
-
-    fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.add_to_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?;
-        log::debug!("nftables: banned IPv6 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
-        self.remove_from_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?;
-        self.remove_from_set_v6(NFT_SET_NOTICE_IPV6, ip, prefixlen)?;
-        log::debug!("nftables: unbanned IPv6 {}/{}", ip, prefixlen);
-        Ok(())
-    }
-
-    fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
-        self.exists_in_set_v6(NFT_SET_NOTICE_IPV6, ip)
-    }
-
-    // TCP fingerprint blocking is not supported via nftables (requires BPF)
-    // These are no-ops in the nftables fallback
-    fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint blocking not supported in nftables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint unblocking not supported in nftables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint blocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box<dyn Error>> {
-        log::warn!("TCP fingerprint unblocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint);
-        Ok(())
-    }
-
-    fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        // TCP fingerprint blocking not supported in nftables mode
-        Ok(false)
-    }
-
-    fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result<bool, Box<dyn Error>> {
-        // TCP fingerprint blocking not supported in nftables mode
-        Ok(false)
-    }
-}
-
-impl Drop for NftablesFirewall {
-    fn drop(&mut self) {
-        // Don't cleanup on drop - rules should persist
-        // Call cleanup() explicitly if needed
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_nftables_available() {
-        // This test just checks if the function runs without panic
-        let _ = NftablesFirewall::is_available();
-    }
-}
+use std::error::Error;
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::process::Command;
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use super::Firewall;
+
+/// Path to nft binary - try common locations
+const NFT_PATHS: &[&str] = &["/usr/bin/nft", "/usr/sbin/nft", "/sbin/nft", "nft"];
+
+/// Table and chain names for synapse nftables rules
+const NFT_TABLE_NAME: &str = "synapse";
+const NFT_CHAIN_INPUT: &str = "synapse_input";
+const NFT_SET_BANNED_IPV4: &str = "banned_ips_v4";
+const NFT_SET_BANNED_IPV6: &str = "banned_ips_v6";
+const NFT_SET_NOTICE_IPV4: &str = "notice_ips_v4";
+const NFT_SET_NOTICE_IPV6: &str = "notice_ips_v6";
+
+/// Find the nft binary path
+fn find_nft_path() -> Option<&'static str> {
+    for path in NFT_PATHS {
+        if Command::new(path)
+            .arg("--version")
+            .output()
+            .map(|o| o.status.success())
+            .unwrap_or(false)
+        {
+            return Some(path);
+        }
+    }
+    None
+}
+
+/// Get the nft command - caches the path after first lookup
+fn nft_cmd() -> Command {
+    static NFT_PATH: std::sync::OnceLock<Option<&'static str>> = std::sync::OnceLock::new();
+    let path = NFT_PATH.get_or_init(find_nft_path);
+    Command::new(path.unwrap_or("nft"))
+}
+
+static NFTABLES_INITIALIZED: AtomicBool = AtomicBool::new(false);
+
+/// Nftables-based firewall implementation for when BPF/XDP is not available
+pub struct NftablesFirewall {
+    initialized: bool,
+}
+
+impl NftablesFirewall {
+    pub fn new() -> Result<Self, Box<dyn Error>> {
+        let mut fw = Self { initialized: false };
+        fw.initialize()?;
+        Ok(fw)
+    }
+
+    /// Check if nftables is available on the system
+    pub fn is_available() -> bool {
+        find_nft_path().is_some()
+    }
+
+    /// Initialize the nftables table, chains, and sets
+    fn initialize(&mut self) -> Result<(), Box<dyn Error>> {
+        if NFTABLES_INITIALIZED.load(Ordering::SeqCst) {
+            self.initialized = true;
+            return Ok(());
+        }
+
+        if !Self::is_available() {
+            return Err("nftables (nft) command not found on system".into());
+        }
+
+        log::info!("Initializing nftables firewall (XDP/BPF fallback)");
+
+        // Use nft command directly for more reliable initialization
+        self.init_with_nft_command()?;
+
+        NFTABLES_INITIALIZED.store(true, Ordering::SeqCst);
+        self.initialized = true;
+        log::info!("Nftables firewall initialized successfully");
+        Ok(())
+    }
+
+    /// Initialize using nft command directly (more reliable)
+    fn init_with_nft_command(&self) -> Result<(), Box<dyn Error>> {
+        // Use nft -f with heredoc-style input for reliable parsing
+        let nft_script = format!(r#"
+table inet {table} {{
+    set {set_v4} {{
+        type ipv4_addr
+        flags interval
+    }}
+    set {set_v6} {{
+        type ipv6_addr
+        flags interval
+    }}
+    set {notice_v4} {{
+        type ipv4_addr
+        flags interval
+    }}
+    set {notice_v6} {{
+        type ipv6_addr
+        flags interval
+    }}
+    chain {chain} {{
+        type filter hook input priority -100; policy accept;
+        ip saddr @{set_v4} drop
+        ip6 saddr @{set_v6} drop
+    }}
+}}
+"#,
+            table = NFT_TABLE_NAME,
+            set_v4 = NFT_SET_BANNED_IPV4,
+            set_v6 = NFT_SET_BANNED_IPV6,
+            notice_v4 = NFT_SET_NOTICE_IPV4,
+            notice_v6 = NFT_SET_NOTICE_IPV6,
+            chain = NFT_CHAIN_INPUT,
+        );
+
+        // First try to delete existing table (ignore errors)
+        let _ = nft_cmd()
+            .args(["delete", "table", "inet", NFT_TABLE_NAME])
+            .output();
+
+        // Create table with all sets and chains
+        let output = nft_cmd()
+            .args(["-f", "-"])
+            .stdin(std::process::Stdio::piped())
+            .stdout(std::process::Stdio::piped())
+            .stderr(std::process::Stdio::piped())
+            .spawn()
+            .and_then(|mut child| {
+                use std::io::Write;
+                if let Some(stdin) = child.stdin.as_mut() {
+                    stdin.write_all(nft_script.as_bytes())?;
+                }
+                child.wait_with_output()
+            })?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+
+            // Check for common error patterns and provide user-friendly messages
+            if stderr.contains("Operation not supported") {
+                log::debug!("nftables kernel support not available: {}", stderr);
+                return Err("nftables kernel support not available (Operation not supported) - this is common in containers or systems without nf_tables kernel module".into());
+            } else if stderr.contains("Permission denied") {
+                return Err("nftables permission denied - ensure synapse runs as root with CAP_NET_ADMIN".into());
+            } else if stderr.contains("No such file or directory") {
+                return Err("nftables failed - nft command or required files not found".into());
+            } else {
+                log::debug!("nftables initialization failed: {}", stderr);
+                return Err(format!("nftables initialization failed: {}", stderr.lines().next().unwrap_or("unknown error")).into());
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Add an IPv4 address/CIDR to a set
+    fn add_to_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 32 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let element = format!("{{ {} }}", addr);
+        let output = nft_cmd()
+            .args(["add", "element", "inet", NFT_TABLE_NAME, set_name, &element])
+            .output()?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            let stdout = String::from_utf8_lossy(&output.stdout);
+            // Ignore "already exists" errors
+            if !stderr.contains("exists") && !stdout.contains("exists") {
+                let error_msg = if stderr.is_empty() && stdout.is_empty() {
+                    format!("exit code: {:?}", output.status.code())
+                } else if stderr.is_empty() {
+                    stdout.to_string()
+                } else {
+                    stderr.to_string()
+                };
+                return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into());
+            }
+        }
+        Ok(())
+    }
+
+    /// Remove an IPv4 address/CIDR from a set
+    fn remove_from_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 32 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let output = nft_cmd()
+            .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)])
+            .output()?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            // Ignore "not found" errors
+            if !stderr.contains("No such") && !stderr.contains("does not exist") {
+                return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into());
+            }
+        }
+        Ok(())
+    }
+
+    /// Check if an IPv4 address exists in a set
+    fn exists_in_set_v4(&self, set_name: &str, ip: Ipv4Addr) -> Result<bool, Box<dyn Error>> {
+        let output = nft_cmd()
+            .args(["list", "set", "inet", NFT_TABLE_NAME, set_name])
+            .output()?;
+
+        if output.status.success() {
+            let stdout = String::from_utf8_lossy(&output.stdout);
+            Ok(stdout.contains(&ip.to_string()))
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// Add an IPv6 address/CIDR to a set
+    fn add_to_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 128 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let element = format!("{{ {} }}", addr);
+        let output = nft_cmd()
+            .args(["add", "element", "inet", NFT_TABLE_NAME, set_name, &element])
+            .output()?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            let stdout = String::from_utf8_lossy(&output.stdout);
+            if !stderr.contains("exists") && !stdout.contains("exists") {
+                let error_msg = if stderr.is_empty() && stdout.is_empty() {
+                    format!("exit code: {:?}", output.status.code())
+                } else if stderr.is_empty() {
+                    stdout.to_string()
+                } else {
+                    stderr.to_string()
+                };
+                return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into());
+            }
+        }
+        Ok(())
+    }
+
+    /// Remove an IPv6 address/CIDR from a set
+    fn remove_from_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box<dyn Error>> {
+        let addr = if prefixlen == 128 {
+            ip.to_string()
+        } else {
+            format!("{}/{}", ip, prefixlen)
+        };
+
+        let output = nft_cmd()
+            .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)])
+            .output()?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            if !stderr.contains("No such") && !stderr.contains("does not exist") {
+                return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into());
+            }
+        }
+        Ok(())
+    }
+
+    /// Check if an IPv6 address exists in a set
+    fn exists_in_set_v6(&self, set_name: &str, ip: Ipv6Addr) -> Result<bool, Box<dyn Error>> {
+        let output = nft_cmd()
+            .args(["list", "set", "inet", NFT_TABLE_NAME, set_name])
+            .output()?;
+
+        if output.status.success() {
+            let stdout = String::from_utf8_lossy(&output.stdout);
+            Ok(stdout.contains(&ip.to_string()))
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// Clean up nftables rules on shutdown
+    pub fn cleanup(&self) -> Result<(), Box<dyn Error>> {
+        log::info!("Cleaning up nftables firewall rules");
+        let _ = nft_cmd()
.args(["delete", "table", "inet", NFT_TABLE_NAME]) + .output(); + NFTABLES_INITIALIZED.store(false, Ordering::SeqCst); + Ok(()) + } +} + +impl Default for NftablesFirewall { + fn default() -> Self { + Self::new().unwrap_or(Self { initialized: false }) + } +} + +impl Firewall for NftablesFirewall { + fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box> { + self.add_to_set_v4(NFT_SET_NOTICE_IPV4, ip, prefixlen)?; + self.add_to_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?; + log::debug!("nftables: banned IPv4 {}/{} with notice", ip, prefixlen); + Ok(()) + } + + fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box> { + self.add_to_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?; + log::debug!("nftables: banned IPv4 {}/{}", ip, prefixlen); + Ok(()) + } + + fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box> { + self.remove_from_set_v4(NFT_SET_BANNED_IPV4, ip, prefixlen)?; + self.remove_from_set_v4(NFT_SET_NOTICE_IPV4, ip, prefixlen)?; + log::debug!("nftables: unbanned IPv4 {}/{}", ip, prefixlen); + Ok(()) + } + + fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result> { + self.exists_in_set_v4(NFT_SET_NOTICE_IPV4, ip) + } + + fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box> { + self.add_to_set_v6(NFT_SET_NOTICE_IPV6, ip, prefixlen)?; + self.add_to_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?; + log::debug!("nftables: banned IPv6 {}/{} with notice", ip, prefixlen); + Ok(()) + } + + fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box> { + self.add_to_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?; + log::debug!("nftables: banned IPv6 {}/{}", ip, prefixlen); + Ok(()) + } + + fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box> { + self.remove_from_set_v6(NFT_SET_BANNED_IPV6, ip, prefixlen)?; + self.remove_from_set_v6(NFT_SET_NOTICE_IPV6, ip, prefixlen)?; + log::debug!("nftables: unbanned IPv6 {}/{}", ip, prefixlen); + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result> { + self.exists_in_set_v6(NFT_SET_NOTICE_IPV6, ip) + } + + // TCP fingerprint blocking is not supported via nftables (requires BPF) + // These are no-ops in the nftables fallback + fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { + log::warn!("TCP fingerprint blocking not supported in nftables fallback mode (fingerprint: {})", fingerprint); + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { + log::warn!("TCP fingerprint unblocking not supported in nftables fallback mode (fingerprint: {})", fingerprint); + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { + log::warn!("TCP fingerprint blocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint); + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { + log::warn!("TCP fingerprint unblocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint); + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + // TCP fingerprint blocking not supported in nftables mode + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + // TCP fingerprint blocking not supported in nftables mode + Ok(false) + } +} + +impl Drop for NftablesFirewall { + fn drop(&mut self) { + // Don't cleanup on drop - rules should persist + // Call cleanup() explicitly if needed + } +} + 
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_nftables_available() { + // This test just checks if the function runs without panic + let _ = NftablesFirewall::is_available(); + } +} diff --git a/src/firewall_noop.rs b/src/firewall_noop.rs index 981cae6..8074bd3 100644 --- a/src/firewall_noop.rs +++ b/src/firewall_noop.rs @@ -1,290 +1,290 @@ -use std::error::Error; -use std::marker::PhantomData; -use std::net::{Ipv4Addr, Ipv6Addr}; - -use serde::{Deserialize, Serialize}; - -/// Enum to represent the active firewall backend -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum FirewallBackend { - Xdp, - Nftables, - Iptables, - None, -} - -impl std::fmt::Display for FirewallBackend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FirewallBackend::Xdp => write!(f, "XDP/BPF"), - FirewallBackend::Nftables => write!(f, "nftables"), - FirewallBackend::Iptables => write!(f, "iptables"), - FirewallBackend::None => write!(f, "none (userland)"), - } - } -} - -/// Configuration option for forcing a specific firewall backend -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum FirewallMode { - #[default] - Auto, - Xdp, - Nftables, - Iptables, - None, -} - -impl std::fmt::Display for FirewallMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FirewallMode::Auto => write!(f, "auto"), - FirewallMode::Xdp => write!(f, "xdp"), - FirewallMode::Nftables => write!(f, "nftables"), - FirewallMode::Iptables => write!(f, "iptables"), - FirewallMode::None => write!(f, "none"), - } - } -} - -pub trait Firewall { - fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; - fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; - fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; - fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result>; - - // IPv6 methods - fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; - fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; - fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; - fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result>; - - // TCP fingerprint blocking methods - fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box>; - fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box>; - fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box>; - fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box>; - fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result>; - fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result>; -} - -pub struct SYNAPSEFirewall<'a> { - _skel: &'a crate::bpf::FilterSkel<'a>, -} - -impl<'a> SYNAPSEFirewall<'a> { - pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self { - Self { _skel: skel } - } -} - -impl<'a> Firewall for SYNAPSEFirewall<'a> { - fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { - Ok(false) - } - - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ipv6(&mut self, 
_ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { - Ok(false) - } - - fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { - Ok(false) - } - - fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { - Ok(false) - } -} - -pub struct NftablesFirewall { - _marker: PhantomData<()>, -} - -impl NftablesFirewall { - pub fn new() -> Result> { - Ok(Self { _marker: PhantomData }) - } - - pub fn is_available() -> bool { - false - } - - pub fn cleanup(&self) -> Result<(), Box> { - Ok(()) - } -} - -impl Firewall for NftablesFirewall { - fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { - Ok(false) - } - - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { - Ok(false) - } - - fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { - Ok(false) - } - - fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { - Ok(false) - } -} - -pub struct IptablesFirewall { - _marker: PhantomData<()>, -} - -impl IptablesFirewall { - pub fn new() -> Result> { - Ok(Self { _marker: PhantomData }) - } - - pub fn is_available() -> bool { - false - } - - pub fn cleanup(&self) -> Result<(), Box> { - Ok(()) - } -} - -impl Firewall for IptablesFirewall { - fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { - Ok(false) - } - - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { - Ok(()) - } - - fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { - Ok(false) - } - - fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> 
Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { - Ok(()) - } - - fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { - Ok(false) - } - - fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { - Ok(false) - } -} +use std::error::Error; +use std::marker::PhantomData; +use std::net::{Ipv4Addr, Ipv6Addr}; + +use serde::{Deserialize, Serialize}; + +/// Enum to represent the active firewall backend +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FirewallBackend { + Xdp, + Nftables, + Iptables, + None, +} + +impl std::fmt::Display for FirewallBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FirewallBackend::Xdp => write!(f, "XDP/BPF"), + FirewallBackend::Nftables => write!(f, "nftables"), + FirewallBackend::Iptables => write!(f, "iptables"), + FirewallBackend::None => write!(f, "none (userland)"), + } + } +} + +/// Configuration option for forcing a specific firewall backend +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum FirewallMode { + #[default] + Auto, + Xdp, + Nftables, + Iptables, + None, +} + +impl std::fmt::Display for FirewallMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FirewallMode::Auto => write!(f, "auto"), + FirewallMode::Xdp => write!(f, "xdp"), + FirewallMode::Nftables => write!(f, "nftables"), + FirewallMode::Iptables => write!(f, "iptables"), + FirewallMode::None => write!(f, "none"), + } + } +} + +pub trait Firewall { + fn ban_ip_with_notice(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; + fn ban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; + fn unban_ip(&mut self, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box>; + fn check_if_notice(&mut self, ip: Ipv4Addr) -> Result>; + + // IPv6 methods + fn ban_ipv6_with_notice(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; + fn ban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; + fn unban_ipv6(&mut self, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box>; + fn check_if_notice_ipv6(&mut self, ip: Ipv6Addr) -> Result>; + + // TCP fingerprint blocking methods + fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box>; + fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box>; + fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box>; + fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box>; + fn is_tcp_fingerprint_blocked(&self, fingerprint: &str) -> Result>; + fn is_tcp_fingerprint_blocked_v6(&self, fingerprint: &str) -> Result>; +} + +pub struct SYNAPSEFirewall<'a> { + _skel: &'a crate::bpf::FilterSkel<'a>, +} + +impl<'a> SYNAPSEFirewall<'a> { + pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self { + Self { _skel: skel } + } +} + +impl<'a> Firewall for SYNAPSEFirewall<'a> { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + 
Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} + +pub struct NftablesFirewall { + _marker: PhantomData<()>, +} + +impl NftablesFirewall { + pub fn new() -> Result> { + Ok(Self { _marker: PhantomData }) + } + + pub fn is_available() -> bool { + false + } + + pub fn cleanup(&self) -> Result<(), Box> { + Ok(()) + } +} + +impl Firewall for NftablesFirewall { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} + +pub struct IptablesFirewall { + _marker: PhantomData<()>, +} + +impl IptablesFirewall { + pub fn new() -> Result> { + Ok(Self { _marker: PhantomData }) + } + + pub fn is_available() -> bool { + false + } + + pub fn cleanup(&self) -> Result<(), Box> { + Ok(()) + } +} + +impl Firewall for IptablesFirewall { + fn ban_ip_with_notice(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ip(&mut self, _ip: Ipv4Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn check_if_notice(&mut self, _ip: Ipv4Addr) -> Result> { + Ok(false) + } + + fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn ban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + fn unban_ipv6(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + Ok(()) + } + + 
fn check_if_notice_ipv6(&mut self, _ip: Ipv6Addr) -> Result> { + Ok(false) + } + + fn block_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn block_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn unblock_tcp_fingerprint_v6(&mut self, _fingerprint: &str) -> Result<(), Box> { + Ok(()) + } + + fn is_tcp_fingerprint_blocked(&self, _fingerprint: &str) -> Result> { + Ok(false) + } + + fn is_tcp_fingerprint_blocked_v6(&self, _fingerprint: &str) -> Result> { + Ok(false) + } +} diff --git a/src/http_client.rs b/src/http_client.rs index 23a8acc..42c75ab 100644 --- a/src/http_client.rs +++ b/src/http_client.rs @@ -1,155 +1,155 @@ -use std::sync::Arc; -use std::time::Duration; -use reqwest::Client; -use anyhow::{Context, Result}; - -/// Shared HTTP client configuration with keepalive settings -#[derive(Debug, Clone)] -pub struct HttpClientConfig { - pub timeout: Duration, - pub connect_timeout: Duration, - pub keepalive_timeout: Duration, - pub max_idle_per_host: usize, - pub user_agent: String, - pub danger_accept_invalid_certs: bool, -} - -impl Default for HttpClientConfig { - fn default() -> Self { - Self { - timeout: Duration::from_secs(30), - connect_timeout: Duration::from_secs(10), - keepalive_timeout: Duration::from_secs(60), // Keep connections alive for 60 seconds - max_idle_per_host: 10, // Allow up to 10 idle connections per host - user_agent: format!("Synapse/{}", env!("CARGO_PKG_VERSION")), - danger_accept_invalid_certs: false, - } - } -} - -/// Shared HTTP client with keepalive configuration -pub struct SharedHttpClient { - client: Arc, - config: HttpClientConfig, -} - -impl SharedHttpClient { - /// Create a new shared HTTP client with the given configuration - pub fn new(config: HttpClientConfig) -> Result { - let client = Client::builder() - .timeout(config.timeout) - .connect_timeout(config.connect_timeout) - .tcp_keepalive(config.keepalive_timeout) - .pool_max_idle_per_host(config.max_idle_per_host) - .user_agent(&config.user_agent) - .danger_accept_invalid_certs(config.danger_accept_invalid_certs) - .build() - .context("Failed to create HTTP client with keepalive configuration")?; - - Ok(Self { - client: Arc::new(client), - config, - }) - } - - /// Create a new shared HTTP client with default configuration - pub fn with_defaults() -> Result { - Self::new(HttpClientConfig::default()) - } - - /// Get a reference to the underlying HTTP client - pub fn client(&self) -> &Client { - &self.client - } - - /// Get a clone of the client Arc for sharing across threads - pub fn client_arc(&self) -> Arc { - self.client.clone() - } - - /// Get the current configuration - pub fn config(&self) -> &HttpClientConfig { - &self.config - } - - /// Update the configuration and recreate the client - pub fn update_config(&mut self, config: HttpClientConfig) -> Result<()> { - let client = Client::builder() - .timeout(config.timeout) - .connect_timeout(config.connect_timeout) - .tcp_keepalive(config.keepalive_timeout) - .pool_max_idle_per_host(config.max_idle_per_host) - .user_agent(&config.user_agent) - .danger_accept_invalid_certs(config.danger_accept_invalid_certs) - .build() - .context("Failed to recreate HTTP client with new configuration")?; - - self.client = Arc::new(client); - self.config = config; - Ok(()) - } -} - -/// Global shared HTTP client instance -static GLOBAL_HTTP_CLIENT: std::sync::OnceLock> = 
std::sync::OnceLock::new(); - -/// Initialize the global HTTP client with default configuration -pub fn init_global_client() -> Result<()> { - let client = SharedHttpClient::with_defaults()?; - GLOBAL_HTTP_CLIENT - .set(Arc::new(client)) - .map_err(|_| anyhow::anyhow!("Global HTTP client already initialized"))?; - Ok(()) -} - -/// Initialize the global HTTP client with custom configuration -pub fn init_global_client_with_config(config: HttpClientConfig) -> Result<()> { - let client = SharedHttpClient::new(config)?; - GLOBAL_HTTP_CLIENT - .set(Arc::new(client)) - .map_err(|_| anyhow::anyhow!("Global HTTP client already initialized"))?; - Ok(()) -} - -/// Get a reference to the global HTTP client -pub fn get_global_client() -> Result> { - GLOBAL_HTTP_CLIENT - .get() - .cloned() - .ok_or_else(|| anyhow::anyhow!("Global HTTP client not initialized")) -} - -/// Get a reference to the underlying reqwest Client from the global client -pub fn get_global_reqwest_client() -> Result> { - let shared_client = get_global_client()?; - Ok(shared_client.client_arc()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_http_client_config_default() { - let config = HttpClientConfig::default(); - assert_eq!(config.timeout, Duration::from_secs(30)); - assert_eq!(config.connect_timeout, Duration::from_secs(10)); - assert_eq!(config.keepalive_timeout, Duration::from_secs(60)); - assert_eq!(config.max_idle_per_host, 10); - assert_eq!(config.user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); - assert!(!config.danger_accept_invalid_certs); - } - - #[test] - fn test_shared_http_client_creation() { - let config = HttpClientConfig::default(); - let client = SharedHttpClient::new(config).unwrap(); - assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); - } - - #[test] - fn test_shared_http_client_with_defaults() { - let client = SharedHttpClient::with_defaults().unwrap(); - assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); - } -} +use std::sync::Arc; +use std::time::Duration; +use reqwest::Client; +use anyhow::{Context, Result}; + +/// Shared HTTP client configuration with keepalive settings +#[derive(Debug, Clone)] +pub struct HttpClientConfig { + pub timeout: Duration, + pub connect_timeout: Duration, + pub keepalive_timeout: Duration, + pub max_idle_per_host: usize, + pub user_agent: String, + pub danger_accept_invalid_certs: bool, +} + +impl Default for HttpClientConfig { + fn default() -> Self { + Self { + timeout: Duration::from_secs(30), + connect_timeout: Duration::from_secs(10), + keepalive_timeout: Duration::from_secs(60), // Keep connections alive for 60 seconds + max_idle_per_host: 10, // Allow up to 10 idle connections per host + user_agent: format!("Synapse/{}", env!("CARGO_PKG_VERSION")), + danger_accept_invalid_certs: false, + } + } +} + +/// Shared HTTP client with keepalive configuration +pub struct SharedHttpClient { + client: Arc, + config: HttpClientConfig, +} + +impl SharedHttpClient { + /// Create a new shared HTTP client with the given configuration + pub fn new(config: HttpClientConfig) -> Result { + let client = Client::builder() + .timeout(config.timeout) + .connect_timeout(config.connect_timeout) + .tcp_keepalive(config.keepalive_timeout) + .pool_max_idle_per_host(config.max_idle_per_host) + .user_agent(&config.user_agent) + .danger_accept_invalid_certs(config.danger_accept_invalid_certs) + .build() + .context("Failed to create HTTP client with keepalive configuration")?; + + Ok(Self { + 
client: Arc::new(client), + config, + }) + } + + /// Create a new shared HTTP client with default configuration + pub fn with_defaults() -> Result { + Self::new(HttpClientConfig::default()) + } + + /// Get a reference to the underlying HTTP client + pub fn client(&self) -> &Client { + &self.client + } + + /// Get a clone of the client Arc for sharing across threads + pub fn client_arc(&self) -> Arc { + self.client.clone() + } + + /// Get the current configuration + pub fn config(&self) -> &HttpClientConfig { + &self.config + } + + /// Update the configuration and recreate the client + pub fn update_config(&mut self, config: HttpClientConfig) -> Result<()> { + let client = Client::builder() + .timeout(config.timeout) + .connect_timeout(config.connect_timeout) + .tcp_keepalive(config.keepalive_timeout) + .pool_max_idle_per_host(config.max_idle_per_host) + .user_agent(&config.user_agent) + .danger_accept_invalid_certs(config.danger_accept_invalid_certs) + .build() + .context("Failed to recreate HTTP client with new configuration")?; + + self.client = Arc::new(client); + self.config = config; + Ok(()) + } +} + +/// Global shared HTTP client instance +static GLOBAL_HTTP_CLIENT: std::sync::OnceLock> = std::sync::OnceLock::new(); + +/// Initialize the global HTTP client with default configuration +pub fn init_global_client() -> Result<()> { + let client = SharedHttpClient::with_defaults()?; + GLOBAL_HTTP_CLIENT + .set(Arc::new(client)) + .map_err(|_| anyhow::anyhow!("Global HTTP client already initialized"))?; + Ok(()) +} + +/// Initialize the global HTTP client with custom configuration +pub fn init_global_client_with_config(config: HttpClientConfig) -> Result<()> { + let client = SharedHttpClient::new(config)?; + GLOBAL_HTTP_CLIENT + .set(Arc::new(client)) + .map_err(|_| anyhow::anyhow!("Global HTTP client already initialized"))?; + Ok(()) +} + +/// Get a reference to the global HTTP client +pub fn get_global_client() -> Result> { + GLOBAL_HTTP_CLIENT + .get() + .cloned() + .ok_or_else(|| anyhow::anyhow!("Global HTTP client not initialized")) +} + +/// Get a reference to the underlying reqwest Client from the global client +pub fn get_global_reqwest_client() -> Result> { + let shared_client = get_global_client()?; + Ok(shared_client.client_arc()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_client_config_default() { + let config = HttpClientConfig::default(); + assert_eq!(config.timeout, Duration::from_secs(30)); + assert_eq!(config.connect_timeout, Duration::from_secs(10)); + assert_eq!(config.keepalive_timeout, Duration::from_secs(60)); + assert_eq!(config.max_idle_per_host, 10); + assert_eq!(config.user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + assert!(!config.danger_accept_invalid_certs); + } + + #[test] + fn test_shared_http_client_creation() { + let config = HttpClientConfig::default(); + let client = SharedHttpClient::new(config).unwrap(); + assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + } + + #[test] + fn test_shared_http_client_with_defaults() { + let client = SharedHttpClient::with_defaults().unwrap(); + assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + } +} diff --git a/src/http_proxy.rs b/src/http_proxy.rs index 6172525..a5211b7 100644 --- a/src/http_proxy.rs +++ b/src/http_proxy.rs @@ -1,5 +1,5 @@ -pub mod bgservice; -pub mod gethosts; -pub mod proxyhttp; -pub mod start; -pub mod webserver; +pub mod bgservice; +pub mod gethosts; +pub mod proxyhttp; 
+pub mod start; +pub mod webserver; diff --git a/src/http_proxy/bgservice.rs b/src/http_proxy/bgservice.rs index 1a8a03a..e1e02bd 100644 --- a/src/http_proxy/bgservice.rs +++ b/src/http_proxy/bgservice.rs @@ -1,255 +1,255 @@ -use crate::utils::discovery::{APIUpstreamProvider, Discovery, FromFileProvider}; -use crate::utils::parceyaml::load_configuration; -use crate::utils::structs::Configuration; -use crate::utils::healthcheck; -use crate::http_proxy::proxyhttp::LB; -use crate::worker::certificate::{request_certificate_from_acme, get_acme_config}; -use async_trait::async_trait; -use dashmap::DashMap; -use futures::channel::mpsc; -use futures::{SinkExt, StreamExt}; -use log::{debug, error, info, warn}; -use pingora_core::server::ShutdownWatch; -use pingora_core::services::background::BackgroundService; -use std::sync::Arc; -use crate::redis::RedisManager; - -#[async_trait] -impl BackgroundService for LB { - async fn start(&self, mut shutdown: ShutdownWatch) { - info!("Starting Pingora background service for upstreams management"); - let (mut tx, mut rx) = mpsc::channel::(1); - let tx_api = tx.clone(); - - // Skip if no upstreams config file is provided (e.g., when using new config format) - if self.config.upstreams_conf.is_empty() { - info!("No upstreams config file specified, Pingora proxy system not initialized"); - return; - } - - info!("Loading upstreams configuration from: {}", self.config.upstreams_conf); - let config = match load_configuration(self.config.upstreams_conf.clone().as_str(), "filepath").await { - Some(cfg) => { - info!("Upstreams configuration loaded successfully"); - cfg - }, - None => { - error!("Failed to load upstreams configuration from: {}", self.config.upstreams_conf); - return; - } - }; - - match config.typecfg.as_str() { - "file" => { - info!("Running File discovery, requested type is: {}", config.typecfg); - tx.send(config).await.unwrap(); - let file_load = FromFileProvider { - path: self.config.upstreams_conf.clone(), - }; - let _ = tokio::spawn(async move { file_load.start(tx).await }); - } - _ => { - error!("Unknown discovery type: {}", config.typecfg); - } - } - - let api_load = APIUpstreamProvider { - address: self.config.config_address.clone(), - masterkey: self.config.master_key.clone(), - config_api_enabled: self.config.config_api_enabled.clone(), - tls_address: self.config.config_tls_address.clone(), - tls_certificate: self.config.config_tls_certificate.clone(), - tls_key_file: self.config.config_tls_key_file.clone(), - file_server_address: self.config.file_server_address.clone(), - file_server_folder: self.config.file_server_folder.clone(), - }; - let _ = tokio::spawn(async move { api_load.start(tx_api).await }); - - // Use AppConfig values as defaults - let (default_healthcheck_method, default_healthcheck_interval) = (self.config.healthcheck_method.clone(), self.config.healthcheck_interval); - let mut healthcheck_method = default_healthcheck_method.clone(); - let mut healthcheck_interval = default_healthcheck_interval; - let mut healthcheck_started = false; - - loop { - tokio::select! 
{ - _ = shutdown.changed() => { - break; - } - val = rx.next() => { - match val { - Some(ss) => { - // Update healthcheck settings from upstreams config if available - if let Some(interval) = ss.healthcheck_interval { - healthcheck_interval = interval; - } - if let Some(method) = &ss.healthcheck_method { - healthcheck_method = method.clone(); - } - - // Start healthcheck on first config load - if !healthcheck_started { - let uu_clone = self.ump_upst.clone(); - let ff_clone = self.ump_full.clone(); - let im_clone = self.ump_byid.clone(); - let method_clone = healthcheck_method.clone(); - let interval_clone = healthcheck_interval; - let _ = tokio::spawn(async move { - healthcheck::hc2(uu_clone, ff_clone, im_clone, (&*method_clone.to_string(), interval_clone.to_string().parse().unwrap())).await - }); - healthcheck_started = true; - } - - // Update arxignis_paths (global paths that work across all hostnames) - self.arxignis_paths.clear(); - for entry in ss.arxignis_paths.iter() { - let (servers, counter) = entry.value(); - let new_counter = std::sync::atomic::AtomicUsize::new(counter.load(std::sync::atomic::Ordering::Relaxed)); - self.arxignis_paths.insert(entry.key().clone(), (servers.clone(), new_counter)); - } - - crate::utils::tools::clone_dashmap_into(&ss.upstreams, &self.ump_full); - crate::utils::tools::clone_dashmap_into(&ss.upstreams, &self.ump_upst); - let current = self.extraparams.load_full(); - let mut new = (*current).clone(); - new.sticky_sessions = ss.extraparams.sticky_sessions; - new.https_proxy_enabled = ss.extraparams.https_proxy_enabled; - new.authentication = ss.extraparams.authentication.clone(); - new.rate_limit = ss.extraparams.rate_limit; - self.extraparams.store(Arc::new(new)); - self.headers.clear(); - - for entry in ss.upstreams.iter() { - let global_key = entry.key().clone(); - let global_values = DashMap::new(); - let mut target_entry = ss.headers.entry(global_key).or_insert_with(DashMap::new); - target_entry.extend(global_values); - self.headers.insert(target_entry.key().to_owned(), target_entry.value().to_owned()); - } - - for path in ss.headers.iter() { - let path_key = path.key().clone(); - let path_headers = path.value().clone(); - self.headers.insert(path_key.clone(), path_headers); - if let Some(global_headers) = ss.headers.get("GLOBAL_HEADERS") { - if let Some(existing_headers) = self.headers.get(&path_key) { - crate::utils::tools::merge_headers(existing_headers.value(), &global_headers); - } - } - } - - // Update upstreams certificate mappings - if let Some(certs_arc) = &self.certificates { - if let Some(certs) = certs_arc.load().as_ref() { - certs.set_upstreams_cert_map(ss.certificates.clone()); - info!("Updated upstreams certificate mappings: {} entries", ss.certificates.len()); - } - } - - // Check and request certificates for new/updated domains - check_and_request_certificates_for_upstreams(&ss.upstreams, &self.config.upstreams_conf).await; - - // info!("Upstreams list is changed, updating to:"); - // print_upstreams(&self.ump_full); - } - None => {} - } - } - } - } - } -} - - -/// Check certificates for domains in upstreams and request from ACME if missing -async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils::structs::UpstreamsDashMap, upstreams_path: &str) { - // Check if ACME is enabled - let _acme_config = match get_acme_config().await { - Some(config) if config.enabled => config, - _ => { - // ACME not enabled, skip certificate checking - return; - } - }; - - // Read upstreams.yaml to check which domains need 
certificates - use serde_yaml; - let parsed: Option = if let Ok(yaml_content) = tokio::fs::read_to_string(upstreams_path).await { - serde_yaml::from_str(&yaml_content).ok() - } else { - warn!("Failed to read upstreams file: {}, skipping certificate checks", upstreams_path); - None - }; - - // Get Redis manager to check for existing certificates - let redis_manager = match RedisManager::get() { - Ok(rm) => rm, - Err(e) => { - warn!("Redis manager not available, skipping certificate check for upstreams: {}", e); - return; - } - }; - - let mut connection = redis_manager.get_connection(); - - // Iterate through all domains in upstreams (outer key is the hostname) - for entry in upstreams.iter() { - let domain = entry.key(); - let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain); - - // Check if this domain needs a certificate by reading upstreams.yaml - let needs_cert = if let Some(ref parsed) = parsed { - if let Some(ref upstreams_map) = parsed.upstreams { - if let Some(host_config) = upstreams_map.get(domain) { - host_config.needs_certificate() - } else { - // Domain not in config, skip - continue; - } - } else { - // No upstreams in config, skip - continue; - } - } else { - // Couldn't read config, skip - continue; - }; - - if !needs_cert { - debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", domain); - continue; - } - - // Check if certificate exists in Redis - // Get prefix from RedisManager - let prefix = RedisManager::get() - .map(|rm| rm.get_prefix().to_string()) - .unwrap_or_else(|_| "ssl-storage".to_string()); - let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_domain); - let cert_exists: u32 = match redis::cmd("EXISTS") - .arg(&fullchain_key) - .query_async(&mut connection) - .await - { - Ok(exists) => exists, - Err(e) => { - warn!("Failed to check certificate existence for domain {}: {}", domain, e); - continue; - } - }; - - // Certificate doesn't exist, request from ACME - if cert_exists == 0 { - info!("Certificate not found in Redis for domain: {}, requesting from ACME", domain); - // Use a placeholder certificate path (will be stored in Redis) - let certificate_path = format!("/tmp/synapse-certs/{}", normalized_domain.replace('.', "_")); - if let Err(e) = request_certificate_from_acme(domain, normalized_domain, &certificate_path).await { - warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); - } else { - info!("Successfully requested certificate from ACME for domain: {}", domain); - } - } else { - info!("Certificate already exists in Redis for domain: {}", domain); - } - } -} +use crate::utils::discovery::{APIUpstreamProvider, Discovery, FromFileProvider}; +use crate::utils::parceyaml::load_configuration; +use crate::utils::structs::Configuration; +use crate::utils::healthcheck; +use crate::http_proxy::proxyhttp::LB; +use crate::worker::certificate::{request_certificate_from_acme, get_acme_config}; +use async_trait::async_trait; +use dashmap::DashMap; +use futures::channel::mpsc; +use futures::{SinkExt, StreamExt}; +use log::{debug, error, info, warn}; +use pingora_core::server::ShutdownWatch; +use pingora_core::services::background::BackgroundService; +use std::sync::Arc; +use crate::redis::RedisManager; + +#[async_trait] +impl BackgroundService for LB { + async fn start(&self, mut shutdown: ShutdownWatch) { + info!("Starting Pingora background service for upstreams management"); + let (mut tx, mut rx) = mpsc::channel::(1); + let tx_api = tx.clone(); + + // Skip if no upstreams 
config file is provided (e.g., when using new config format) + if self.config.upstreams_conf.is_empty() { + info!("No upstreams config file specified, Pingora proxy system not initialized"); + return; + } + + info!("Loading upstreams configuration from: {}", self.config.upstreams_conf); + let config = match load_configuration(self.config.upstreams_conf.clone().as_str(), "filepath").await { + Some(cfg) => { + info!("Upstreams configuration loaded successfully"); + cfg + }, + None => { + error!("Failed to load upstreams configuration from: {}", self.config.upstreams_conf); + return; + } + }; + + match config.typecfg.as_str() { + "file" => { + info!("Running File discovery, requested type is: {}", config.typecfg); + tx.send(config).await.unwrap(); + let file_load = FromFileProvider { + path: self.config.upstreams_conf.clone(), + }; + let _ = tokio::spawn(async move { file_load.start(tx).await }); + } + _ => { + error!("Unknown discovery type: {}", config.typecfg); + } + } + + let api_load = APIUpstreamProvider { + address: self.config.config_address.clone(), + masterkey: self.config.master_key.clone(), + config_api_enabled: self.config.config_api_enabled.clone(), + tls_address: self.config.config_tls_address.clone(), + tls_certificate: self.config.config_tls_certificate.clone(), + tls_key_file: self.config.config_tls_key_file.clone(), + file_server_address: self.config.file_server_address.clone(), + file_server_folder: self.config.file_server_folder.clone(), + }; + let _ = tokio::spawn(async move { api_load.start(tx_api).await }); + + // Use AppConfig values as defaults + let (default_healthcheck_method, default_healthcheck_interval) = (self.config.healthcheck_method.clone(), self.config.healthcheck_interval); + let mut healthcheck_method = default_healthcheck_method.clone(); + let mut healthcheck_interval = default_healthcheck_interval; + let mut healthcheck_started = false; + + loop { + tokio::select! 
{ + _ = shutdown.changed() => { + break; + } + val = rx.next() => { + match val { + Some(ss) => { + // Update healthcheck settings from upstreams config if available + if let Some(interval) = ss.healthcheck_interval { + healthcheck_interval = interval; + } + if let Some(method) = &ss.healthcheck_method { + healthcheck_method = method.clone(); + } + + // Start healthcheck on first config load + if !healthcheck_started { + let uu_clone = self.ump_upst.clone(); + let ff_clone = self.ump_full.clone(); + let im_clone = self.ump_byid.clone(); + let method_clone = healthcheck_method.clone(); + let interval_clone = healthcheck_interval; + let _ = tokio::spawn(async move { + healthcheck::hc2(uu_clone, ff_clone, im_clone, (&*method_clone.to_string(), interval_clone.to_string().parse().unwrap())).await + }); + healthcheck_started = true; + } + + // Update arxignis_paths (global paths that work across all hostnames) + self.arxignis_paths.clear(); + for entry in ss.arxignis_paths.iter() { + let (servers, counter) = entry.value(); + let new_counter = std::sync::atomic::AtomicUsize::new(counter.load(std::sync::atomic::Ordering::Relaxed)); + self.arxignis_paths.insert(entry.key().clone(), (servers.clone(), new_counter)); + } + + crate::utils::tools::clone_dashmap_into(&ss.upstreams, &self.ump_full); + crate::utils::tools::clone_dashmap_into(&ss.upstreams, &self.ump_upst); + let current = self.extraparams.load_full(); + let mut new = (*current).clone(); + new.sticky_sessions = ss.extraparams.sticky_sessions; + new.https_proxy_enabled = ss.extraparams.https_proxy_enabled; + new.authentication = ss.extraparams.authentication.clone(); + new.rate_limit = ss.extraparams.rate_limit; + self.extraparams.store(Arc::new(new)); + self.headers.clear(); + + for entry in ss.upstreams.iter() { + let global_key = entry.key().clone(); + let global_values = DashMap::new(); + let mut target_entry = ss.headers.entry(global_key).or_insert_with(DashMap::new); + target_entry.extend(global_values); + self.headers.insert(target_entry.key().to_owned(), target_entry.value().to_owned()); + } + + for path in ss.headers.iter() { + let path_key = path.key().clone(); + let path_headers = path.value().clone(); + self.headers.insert(path_key.clone(), path_headers); + if let Some(global_headers) = ss.headers.get("GLOBAL_HEADERS") { + if let Some(existing_headers) = self.headers.get(&path_key) { + crate::utils::tools::merge_headers(existing_headers.value(), &global_headers); + } + } + } + + // Update upstreams certificate mappings + if let Some(certs_arc) = &self.certificates { + if let Some(certs) = certs_arc.load().as_ref() { + certs.set_upstreams_cert_map(ss.certificates.clone()); + info!("Updated upstreams certificate mappings: {} entries", ss.certificates.len()); + } + } + + // Check and request certificates for new/updated domains + check_and_request_certificates_for_upstreams(&ss.upstreams, &self.config.upstreams_conf).await; + + // info!("Upstreams list is changed, updating to:"); + // print_upstreams(&self.ump_full); + } + None => {} + } + } + } + } + } +} + + +/// Check certificates for domains in upstreams and request from ACME if missing +async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils::structs::UpstreamsDashMap, upstreams_path: &str) { + // Check if ACME is enabled + let _acme_config = match get_acme_config().await { + Some(config) if config.enabled => config, + _ => { + // ACME not enabled, skip certificate checking + return; + } + }; + + // Read upstreams.yaml to check which domains need 
certificates + use serde_yaml; + let parsed: Option = if let Ok(yaml_content) = tokio::fs::read_to_string(upstreams_path).await { + serde_yaml::from_str(&yaml_content).ok() + } else { + warn!("Failed to read upstreams file: {}, skipping certificate checks", upstreams_path); + None + }; + + // Get Redis manager to check for existing certificates + let redis_manager = match RedisManager::get() { + Ok(rm) => rm, + Err(e) => { + warn!("Redis manager not available, skipping certificate check for upstreams: {}", e); + return; + } + }; + + let mut connection = redis_manager.get_connection(); + + // Iterate through all domains in upstreams (outer key is the hostname) + for entry in upstreams.iter() { + let domain = entry.key(); + let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain); + + // Check if this domain needs a certificate by reading upstreams.yaml + let needs_cert = if let Some(ref parsed) = parsed { + if let Some(ref upstreams_map) = parsed.upstreams { + if let Some(host_config) = upstreams_map.get(domain) { + host_config.needs_certificate() + } else { + // Domain not in config, skip + continue; + } + } else { + // No upstreams in config, skip + continue; + } + } else { + // Couldn't read config, skip + continue; + }; + + if !needs_cert { + debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", domain); + continue; + } + + // Check if certificate exists in Redis + // Get prefix from RedisManager + let prefix = RedisManager::get() + .map(|rm| rm.get_prefix().to_string()) + .unwrap_or_else(|_| "ssl-storage".to_string()); + let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_domain); + let cert_exists: u32 = match redis::cmd("EXISTS") + .arg(&fullchain_key) + .query_async(&mut connection) + .await + { + Ok(exists) => exists, + Err(e) => { + warn!("Failed to check certificate existence for domain {}: {}", domain, e); + continue; + } + }; + + // Certificate doesn't exist, request from ACME + if cert_exists == 0 { + info!("Certificate not found in Redis for domain: {}, requesting from ACME", domain); + // Use a placeholder certificate path (will be stored in Redis) + let certificate_path = format!("/tmp/synapse-certs/{}", normalized_domain.replace('.', "_")); + if let Err(e) = request_certificate_from_acme(domain, normalized_domain, &certificate_path).await { + warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + } else { + info!("Successfully requested certificate from ACME for domain: {}", domain); + } + } else { + info!("Certificate already exists in Redis for domain: {}", domain); + } + } +} diff --git a/src/http_proxy/gethosts.rs b/src/http_proxy/gethosts.rs index 1ef3060..ceed158 100644 --- a/src/http_proxy/gethosts.rs +++ b/src/http_proxy/gethosts.rs @@ -1,156 +1,156 @@ -use crate::utils::structs::InnerMap; -use crate::http_proxy::proxyhttp::LB; -use async_trait::async_trait; -use std::sync::atomic::Ordering; -use log::debug; - -#[async_trait] -pub trait GetHost { - fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option; - fn get_header(&self, peer: &str, path: &str) -> Option>; -} -#[async_trait] -impl GetHost for LB { - fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option { - if let Some(b) = backend_id { - if let Some(bb) = self.ump_byid.get(b) { - // println!("BIB :===> {:?}", Some(bb.value())); - return Some(bb.value().clone()); - } - } - - // Check arxignis_paths first - these paths work regardless of hostname - // Try exact match first 
- if let Some(arxignis_path_entry) = self.arxignis_paths.get(path) { - let (servers, index) = arxignis_path_entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - debug!("Using Gen0Sec path {} -> {}", path, servers[idx].address); - return Some(servers[idx].clone()); - } - } - // If no exact match, try prefix/wildcard matching - check if any configured path is a prefix of the request path - // Collect all matches and use the longest one (most specific match) - let mut best_match: Option<(String, InnerMap, usize)> = None; - for entry in self.arxignis_paths.iter() { - let pattern = entry.key(); - // Handle wildcard patterns ending with /* - strip the /* for matching - let (pattern_prefix, is_wildcard) = if pattern.ends_with("/*") { - (pattern.strip_suffix("/*").unwrap_or(pattern.as_str()), true) - } else { - (pattern.as_str(), false) - }; - - // Check if the request path starts with the pattern prefix (prefix match) - if path.starts_with(pattern_prefix) { - // For wildcard patterns (ending with /*), match any path that starts with the prefix - // For non-wildcard patterns, ensure it's a proper path segment match - let is_valid_match = if is_wildcard { - // Wildcard pattern: match if path starts with prefix (already checked above) - true - } else if pattern_prefix.ends_with('/') { - // Pattern ends with /, so it matches any path starting with it - true - } else if path.len() == pattern_prefix.len() { - // Exact match (already handled above, but keep for completeness) - true - } else if let Some(next_char) = path.chars().skip(pattern_prefix.len()).next() { - // Next character after prefix should be / for proper path segment match - next_char == '/' - } else { - false - }; - - if is_valid_match { - let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - let matched_server = servers[idx].clone(); - let prefix_len = pattern_prefix.len(); - // Keep the longest (most specific) match based on the prefix length - if best_match.as_ref().map_or(true, |(_, _, best_len)| prefix_len > *best_len) { - best_match = Some((pattern.clone(), matched_server, prefix_len)); - } - } - } - } - } - if let Some((pattern, server, _)) = best_match { - debug!("Using Gen0Sec path pattern {} -> {} (matched path: {})", pattern, server.address, path); - return Some(server); - } - // If no prefix match, try progressively shorter paths (same logic as regular upstreams) - let mut current_path = path.to_string(); - loop { - if let Some(arxignis_path_entry) = self.arxignis_paths.get(¤t_path) { - let (servers, index) = arxignis_path_entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - debug!("Using Gen0Sec path {} -> {} (matched from {})", current_path, servers[idx].address, path); - return Some(servers[idx].clone()); - } - } - if let Some(pos) = current_path.rfind('/') { - current_path.truncate(pos); - } else { - break; - } - } - - let host_entry = self.ump_upst.get(peer)?; - let mut current_path = path.to_string(); - let mut best_match: Option = None; - loop { - if let Some(entry) = host_entry.get(¤t_path) { - let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - best_match = Some(servers[idx].clone()); - break; - } - } - if let Some(pos) = current_path.rfind('/') { - current_path.truncate(pos); - } else { - break; - } - } - if best_match.is_none() { - if let 
Some(entry) = host_entry.get("/") { - let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - best_match = Some(servers[idx].clone()); - } - } - } - // println!("Best Match :===> {:?}", best_match); - best_match - } - fn get_header(&self, peer: &str, path: &str) -> Option> { - let host_entry = self.headers.get(peer)?; - let mut current_path = path.to_string(); - let mut best_match: Option> = None; - loop { - if let Some(entry) = host_entry.get(¤t_path) { - if !entry.value().is_empty() { - best_match = Some(entry.value().clone()); - break; - } - } - if let Some(pos) = current_path.rfind('/') { - current_path.truncate(pos); - } else { - break; - } - } - if best_match.is_none() { - if let Some(entry) = host_entry.get("/") { - if !entry.value().is_empty() { - best_match = Some(entry.value().clone()); - } - } - } - best_match - } -} +use crate::utils::structs::InnerMap; +use crate::http_proxy::proxyhttp::LB; +use async_trait::async_trait; +use std::sync::atomic::Ordering; +use log::debug; + +#[async_trait] +pub trait GetHost { + fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option; + fn get_header(&self, peer: &str, path: &str) -> Option>; +} +#[async_trait] +impl GetHost for LB { + fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option { + if let Some(b) = backend_id { + if let Some(bb) = self.ump_byid.get(b) { + // println!("BIB :===> {:?}", Some(bb.value())); + return Some(bb.value().clone()); + } + } + + // Check arxignis_paths first - these paths work regardless of hostname + // Try exact match first + if let Some(arxignis_path_entry) = self.arxignis_paths.get(path) { + let (servers, index) = arxignis_path_entry.value(); + if !servers.is_empty() { + let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); + debug!("Using Gen0Sec path {} -> {}", path, servers[idx].address); + return Some(servers[idx].clone()); + } + } + // If no exact match, try prefix/wildcard matching - check if any configured path is a prefix of the request path + // Collect all matches and use the longest one (most specific match) + let mut best_match: Option<(String, InnerMap, usize)> = None; + for entry in self.arxignis_paths.iter() { + let pattern = entry.key(); + // Handle wildcard patterns ending with /* - strip the /* for matching + let (pattern_prefix, is_wildcard) = if pattern.ends_with("/*") { + (pattern.strip_suffix("/*").unwrap_or(pattern.as_str()), true) + } else { + (pattern.as_str(), false) + }; + + // Check if the request path starts with the pattern prefix (prefix match) + if path.starts_with(pattern_prefix) { + // For wildcard patterns (ending with /*), match any path that starts with the prefix + // For non-wildcard patterns, ensure it's a proper path segment match + let is_valid_match = if is_wildcard { + // Wildcard pattern: match if path starts with prefix (already checked above) + true + } else if pattern_prefix.ends_with('/') { + // Pattern ends with /, so it matches any path starting with it + true + } else if path.len() == pattern_prefix.len() { + // Exact match (already handled above, but keep for completeness) + true + } else if let Some(next_char) = path.chars().skip(pattern_prefix.len()).next() { + // Next character after prefix should be / for proper path segment match + next_char == '/' + } else { + false + }; + + if is_valid_match { + let (servers, index) = entry.value(); + if !servers.is_empty() { + let idx = index.fetch_add(1, Ordering::Relaxed) % 
servers.len(); + let matched_server = servers[idx].clone(); + let prefix_len = pattern_prefix.len(); + // Keep the longest (most specific) match based on the prefix length + if best_match.as_ref().map_or(true, |(_, _, best_len)| prefix_len > *best_len) { + best_match = Some((pattern.clone(), matched_server, prefix_len)); + } + } + } + } + } + if let Some((pattern, server, _)) = best_match { + debug!("Using Gen0Sec path pattern {} -> {} (matched path: {})", pattern, server.address, path); + return Some(server); + } + // If no prefix match, try progressively shorter paths (same logic as regular upstreams) + let mut current_path = path.to_string(); + loop { + if let Some(arxignis_path_entry) = self.arxignis_paths.get(¤t_path) { + let (servers, index) = arxignis_path_entry.value(); + if !servers.is_empty() { + let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); + debug!("Using Gen0Sec path {} -> {} (matched from {})", current_path, servers[idx].address, path); + return Some(servers[idx].clone()); + } + } + if let Some(pos) = current_path.rfind('/') { + current_path.truncate(pos); + } else { + break; + } + } + + let host_entry = self.ump_upst.get(peer)?; + let mut current_path = path.to_string(); + let mut best_match: Option = None; + loop { + if let Some(entry) = host_entry.get(¤t_path) { + let (servers, index) = entry.value(); + if !servers.is_empty() { + let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); + best_match = Some(servers[idx].clone()); + break; + } + } + if let Some(pos) = current_path.rfind('/') { + current_path.truncate(pos); + } else { + break; + } + } + if best_match.is_none() { + if let Some(entry) = host_entry.get("/") { + let (servers, index) = entry.value(); + if !servers.is_empty() { + let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); + best_match = Some(servers[idx].clone()); + } + } + } + // println!("Best Match :===> {:?}", best_match); + best_match + } + fn get_header(&self, peer: &str, path: &str) -> Option> { + let host_entry = self.headers.get(peer)?; + let mut current_path = path.to_string(); + let mut best_match: Option> = None; + loop { + if let Some(entry) = host_entry.get(¤t_path) { + if !entry.value().is_empty() { + best_match = Some(entry.value().clone()); + break; + } + } + if let Some(pos) = current_path.rfind('/') { + current_path.truncate(pos); + } else { + break; + } + } + if best_match.is_none() { + if let Some(entry) = host_entry.get("/") { + if !entry.value().is_empty() { + best_match = Some(entry.value().clone()); + } + } + } + best_match + } +} diff --git a/src/http_proxy/proxyhttp.rs b/src/http_proxy/proxyhttp.rs index f2470a0..95f8495 100644 --- a/src/http_proxy/proxyhttp.rs +++ b/src/http_proxy/proxyhttp.rs @@ -1,1026 +1,1026 @@ -use crate::utils::structs::{AppConfig, Extraparams, Headers, InnerMap, UpstreamsDashMap, UpstreamsIdMap}; -use crate::http_proxy::gethosts::GetHost; -use crate::waf::wirefilter::{evaluate_waf_for_pingora_request, WafAction}; -use crate::waf::actions::captcha::{validate_captcha_token, apply_captcha_challenge_with_token, generate_captcha_token}; -use arc_swap::ArcSwap; -use async_trait::async_trait; -use axum::body::Bytes; -use dashmap::DashMap; -use log::{debug, error, info, warn}; -use once_cell::sync::Lazy; -use pingora_http::{RequestHeader, ResponseHeader, StatusCode}; -use pingora_core::prelude::*; -use pingora_core::ErrorSource::{Upstream, Internal as ErrorSourceInternal}; -use pingora_core::{Error, ErrorType::HTTPStatus, RetryType, ImmutStr}; -use 
pingora_core::listeners::ALPN; -use pingora_core::prelude::HttpPeer; -use pingora_limits::rate::Rate; -use pingora_proxy::{ProxyHttp, Session}; -use serde_json; -use std::sync::Arc; -use std::sync::atomic::AtomicUsize; -use std::time::Duration; -use tokio::time::Instant; -use hyper::http; - -static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); -static WAF_RATE_LIMITERS: Lazy>> = Lazy::new(|| DashMap::new()); - -#[derive(Clone)] -pub struct LB { - pub ump_upst: Arc, - pub ump_full: Arc, - pub ump_byid: Arc, - pub arxignis_paths: Arc, AtomicUsize)>>, - pub headers: Arc, - pub config: Arc, - pub extraparams: Arc>, - pub tcp_fingerprint_collector: Option>, - pub certificates: Option>>>>, -} - -pub struct Context { - backend_id: String, - start_time: Instant, - upstream_start_time: Option, - hostname: Option, - upstream_peer: Option, - extraparams: arc_swap::Guard>, - tls_fingerprint: Option>, - request_body: Vec, - malware_detected: bool, - malware_response_sent: bool, - waf_result: Option, - threat_data: Option, - upstream_time: Option, - disable_access_log: bool, -} - -#[async_trait] -impl ProxyHttp for LB { - type CTX = Context; - fn new_ctx(&self) -> Self::CTX { - Context { - backend_id: String::new(), - start_time: Instant::now(), - upstream_start_time: None, - hostname: None, - upstream_peer: None, - extraparams: self.extraparams.load(), - tls_fingerprint: None, - request_body: Vec::new(), - malware_detected: false, - malware_response_sent: false, - waf_result: None, - threat_data: None, - upstream_time: None, - disable_access_log: false, - } - } - async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { - // Enable body buffering for content scanning - session.enable_retry_buffering(); - - let ep = _ctx.extraparams.clone(); - - // Userland access rules check (fallback when eBPF/XDP is not available) - // Check if IP is blocked by access rules - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let client_ip: std::net::IpAddr = peer_addr.ip().into(); - - // Check if IP is blocked - if crate::access_rules::is_ip_blocked_by_access_rules(client_ip) { - log::info!("Userland access rules: Blocked request from IP: {} (matched block rule)", client_ip); - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-Block-Reason", "access_rules").ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - - // Try to get TLS fingerprint if available - // Use fallback lookup to handle PROXY protocol address mismatches - if _ctx.tls_fingerprint.is_none() { - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); - if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) { - _ctx.tls_fingerprint = Some(fingerprint.clone()); - debug!( - "TLS Fingerprint retrieved for session - Peer: {}, JA4: {}, SNI: {:?}, ALPN: {:?}", - std_addr, - fingerprint.ja4, - fingerprint.sni, - fingerprint.alpn - ); - } else { - debug!("No TLS fingerprint found in storage for peer: {} (PROXY protocol may cause this)", std_addr); - } - } - } - - // Get threat intelligence data BEFORE WAF evaluation - // This ensures threat intelligence is available in access logs even when WAF blocks/challenges early - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - if 
_ctx.threat_data.is_none() { - match crate::threat::get_threat_intel(&peer_addr.ip().to_string()).await { - Ok(Some(threat_response)) => { - _ctx.threat_data = Some(threat_response); - debug!("Threat intelligence retrieved for IP: {}", peer_addr.ip()); - } - Ok(None) => { - debug!("No threat intelligence data for IP: {}", peer_addr.ip()); - } - Err(e) => { - debug!("Threat intelligence error for IP {}: {}", peer_addr.ip(), e); - } - } - } - } - - // Evaluate WAF rules - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); - match evaluate_waf_for_pingora_request(session.req_header(), b"", socket_addr).await { - Ok(Some(waf_result)) => { - debug!("WAF rule matched: rule={}, id={}, action={:?}", waf_result.rule_name, waf_result.rule_id, waf_result.action); - - // Store threat response from WAF result if available (WAF already fetched it) - if let Some(threat_resp) = waf_result.threat_response.clone() { - _ctx.threat_data = Some(threat_resp); - debug!("Threat intelligence retrieved from WAF evaluation for IP: {}", peer_addr.ip()); - } - - // Store WAF result in context for access logging - _ctx.waf_result = Some(waf_result.clone()); - - match waf_result.action { - WafAction::Block => { - info!("WAF blocked request: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - WafAction::Challenge => { - info!("WAF challenge required: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); - - // Check for captcha token in cookies or headers - let mut captcha_token: Option = None; - - // Check cookies for captcha_token - if let Some(cookies) = session.req_header().headers.get("cookie") { - if let Ok(cookie_str) = cookies.to_str() { - for cookie in cookie_str.split(';') { - let trimmed = cookie.trim(); - if let Some(value) = trimmed.strip_prefix("captcha_token=") { - captcha_token = Some(value.to_string()); - break; - } - } - } - } - - // Check X-Captcha-Token header if not found in cookies - if captcha_token.is_none() { - if let Some(token_header) = session.req_header().headers.get("x-captcha-token") { - if let Ok(token_str) = token_header.to_str() { - captcha_token = Some(token_str.to_string()); - } - } - } - - // Validate token if present - let token_valid = if let Some(token) = &captcha_token { - let user_agent = session.req_header().headers - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - match validate_captcha_token(token, &peer_addr.ip().to_string(), &user_agent).await { - Ok(valid) => { - if valid { - debug!("Captcha token validated successfully"); - } else { - debug!("Captcha token validation failed"); - } - valid - } - Err(e) => { - error!("Captcha token validation error: {}", e); - false - } - } - } else { - false - }; - - if !token_valid { - // Generate a new token (don't reuse invalid token) - let jwt_token = { - let user_agent = session.req_header().headers - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - match generate_captcha_token( - peer_addr.ip().to_string(), - user_agent, - None, // JA4 fingerprint not 
available here - ).await { - Ok(token) => token.token, - Err(e) => { - error!("Failed to generate captcha token: {}", e); - // Fallback to challenge without token - match apply_captcha_challenge_with_token("") { - Ok(html) => { - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(html)), true).await?; - return Ok(true); - } - Err(e) => { - error!("Failed to apply captcha challenge: {}", e); - // Block the request if captcha fails - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - } - } - }; - - // Return captcha challenge page - match apply_captcha_challenge_with_token(&jwt_token) { - Ok(html) => { - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); - header.insert_header("Set-Cookie", format!("captcha_token={}; Path=/; HttpOnly; SameSite=Lax", jwt_token)).ok(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(html)), true).await?; - return Ok(true); - } - Err(e) => { - error!("Failed to apply captcha challenge: {}", e); - // Block the request if captcha fails - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - } else { - // Token is valid, allow request to continue - debug!("Captcha token validated, allowing request"); - } - } - WafAction::RateLimit => { - // Get rate limit config from waf_result - if let Some(rate_limit_config) = &waf_result.rate_limit_config { - let period_secs = rate_limit_config.period_secs(); - let requests_limit = rate_limit_config.requests_count(); - - // Get or create rate limiter for this rule - let rate_limiter = WAF_RATE_LIMITERS - .entry(waf_result.rule_id.clone()) - .or_insert_with(|| { - debug!("Creating new rate limiter for rule {}: {} requests per {} seconds", - waf_result.rule_id, requests_limit, period_secs); - Arc::new(Rate::new(Duration::from_secs(period_secs))) - }) - .clone(); - - // Use client IP as the rate key - let rate_key = peer_addr.ip().to_string(); - let curr_window_requests = rate_limiter.observe(&rate_key, 1); - - if curr_window_requests > requests_limit as isize { - info!("Rate limit exceeded: rule={}, id={}, ip={}, requests={}/{}", - waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); - - let body = serde_json::json!({ - "error": "Too Many Requests", - "message": format!("Rate limit exceeded: {} requests per {} seconds", requests_limit, period_secs), - "rule": &waf_result.rule_name, - "rule_id": &waf_result.rule_id - }).to_string(); - - let mut header = ResponseHeader::build(429, None).unwrap(); - header.insert_header("X-Rate-Limit-Limit", 
requests_limit.to_string()).ok(); - header.insert_header("X-Rate-Limit-Remaining", "0").ok(); - header.insert_header("X-Rate-Limit-Reset", period_secs.to_string()).ok(); - header.insert_header("X-WAF-Rule", &waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", &waf_result.rule_id).ok(); - header.insert_header("Content-Type", "application/json").ok(); - - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(body)), true).await?; - return Ok(true); - } else { - debug!("Rate limit check passed: rule={}, id={}, ip={}, requests={}/{}", - waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); - } - } else { - warn!("Rate limit action triggered but no config found for rule {}", waf_result.rule_id); - } - } - WafAction::Allow => { - debug!("WAF allowed request: rule={}, id={}", waf_result.rule_name, waf_result.rule_id); - // Allow the request to continue - } - } - } - Ok(None) => { - // No WAF rules matched, allow request to continue - debug!("WAF: No rules matched for uri={}", session.req_header().uri); - } - Err(e) => { - error!("WAF evaluation error: {}", e); - // On error, allow request to continue (fail open) - } - } - } else { - debug!("WAF: No peer address available for request"); - } - - let hostname = return_header_host(&session); - _ctx.hostname = hostname; - - let mut backend_id = None; - - if ep.sticky_sessions { - if let Some(cookies) = session.req_header().headers.get("cookie") { - if let Ok(cookie_str) = cookies.to_str() { - for cookie in cookie_str.split(';') { - let trimmed = cookie.trim(); - if let Some(value) = trimmed.strip_prefix("backend_id=") { - backend_id = Some(value); - break; - } - } - } - } - } - - match _ctx.hostname.as_ref() { - None => return Ok(false), - Some(host) => { - // let optioninnermap = self.get_host(host.as_str(), host.as_str(), backend_id); - let optioninnermap = self.get_host(host.as_str(), session.req_header().uri.path(), backend_id); - match optioninnermap { - None => return Ok(false), - Some(ref innermap) => { - // Check for HTTPS redirect before rate limiting - if ep.https_proxy_enabled.unwrap_or(false) || innermap.https_proxy_enabled { - if let Some(stream) = session.stream() { - if stream.get_ssl().is_none() { - // HTTP request - redirect to HTTPS - let uri = session.req_header().uri.path_and_query().map_or("/", |pq| pq.as_str()); - let port = self.config.proxy_port_tls.unwrap_or(403); - let redirect_url = format!("https://{}:{}{}", host, port, uri); - let mut redirect_response = ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None)?; - redirect_response.insert_header("Location", redirect_url)?; - redirect_response.insert_header("Content-Length", "0")?; - session.set_keepalive(None); - session.write_response_header(Box::new(redirect_response), false).await?; - return Ok(true); - } - } - } - if let Some(rate) = innermap.rate_limit.or(ep.rate_limit) { - // let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip().to_string()).unwrap_or_else(|| host.to_string()); - let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip()); - let curr_window_requests = RATE_LIMITER.observe(&rate_key, 1); - if curr_window_requests > rate { - let mut header = ResponseHeader::build(429, None).unwrap(); - header.insert_header("X-Rate-Limit-Limit", rate.to_string()).unwrap(); - header.insert_header("X-Rate-Limit-Remaining", "0").unwrap(); - 
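
The per-rule limiter pattern used for WafAction::RateLimit rests on the same two pingora_limits calls the patch itself uses, Rate::new and observe (observe records an event and returns the count seen in the current window). A minimal sketch, with allow() and LIMITERS as illustrative names:

    use std::sync::Arc;
    use std::time::Duration;
    use dashmap::DashMap;
    use once_cell::sync::Lazy;
    use pingora_limits::rate::Rate;

    // One fixed-window limiter per rule id, created lazily on first use.
    static LIMITERS: Lazy<DashMap<String, Arc<Rate>>> = Lazy::new(DashMap::new);

    // Count one request for (rule, ip) and report whether it is still allowed.
    fn allow(rule_id: &str, ip: &str, limit: isize, period: Duration) -> bool {
        let limiter = LIMITERS
            .entry(rule_id.to_string())
            .or_insert_with(|| Arc::new(Rate::new(period)))
            .clone();
        // observe() returns the current-window count including this event
        limiter.observe(&ip, 1) <= limit
    }

The per-host limiter earlier in request_filter is the same idea with a single shared Rate keyed by client IP instead of a map of per-rule limiters.
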
header.insert_header("X-Rate-Limit-Reset", "1").unwrap(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - debug!("Rate limited: {:?}, {}", rate_key, rate); - return Ok(true); - } - } - } - } - _ctx.upstream_peer = optioninnermap.clone(); - // Set disable_access_log flag from upstream config - if let Some(ref innermap) = optioninnermap { - _ctx.disable_access_log = innermap.disable_access_log; - } - } - } - Ok(false) - } - async fn upstream_peer(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result> { - // Check if malware was detected, send JSON response and prevent forwarding - if ctx.malware_detected && !ctx.malware_response_sent { - // Check if response has already been written - if session.response_written().is_some() { - warn!("Response already written, cannot block malware request in upstream_peer"); - ctx.malware_response_sent = true; - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } - - info!("Blocking request due to malware detection"); - - // Build JSON response - let json_response = serde_json::json!({ - "success": false, - "error": "Request blocked", - "reason": "malware_detected", - "message": "Malware detected in request" - }); - let json_body = Bytes::from(json_response.to_string()); - - // Build response header - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "application/json").ok(); - header.insert_header("X-Content-Scan-Result", "malware_detected").ok(); - - session.set_keepalive(None); - - // Try to write response, handle error if response already sent - match session.write_response_header(Box::new(header), false).await { - Ok(_) => { - match session.write_response_body(Some(json_body), true).await { - Ok(_) => { - ctx.malware_response_sent = true; - } - Err(e) => { - warn!("Failed to write response body for malware block in upstream_peer: {}", e); - ctx.malware_response_sent = true; - } - } - } - Err(e) => { - warn!("Failed to write response header for malware block in upstream_peer: {}", e); - ctx.malware_response_sent = true; - } - } - - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } - - // let host_name = return_header_host(&session); - match ctx.hostname.as_ref() { - Some(_hostname) => { - match ctx.upstream_peer.as_ref() { - // Some((address, port, ssl, is_h2, https_proxy_enabled)) => { - Some(innermap) => { - let mut peer = Box::new(HttpPeer::new((innermap.address.clone(), innermap.port.clone()), innermap.ssl_enabled, String::new())); - // if session.is_http2() { - if innermap.http2_enabled { - peer.options.alpn = ALPN::H2; - } - if innermap.ssl_enabled { - // Use upstream server address for SNI, not the incoming hostname - // This ensures the SNI matches what the upstream server expects - peer.sni = innermap.address.clone(); - peer.options.verify_cert = false; - peer.options.verify_hostname = false; - } - - ctx.backend_id = format!("{}:{}:{}", innermap.address.clone(), innermap.port.clone(), innermap.ssl_enabled); - Ok(peer) - } - None => { - if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { - error!("Failed to send error response: {:?}", e); - } - Err(Box::new(Error { - etype: HTTPStatus(502), - esource: 
Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Upstream not found")), - })) - } - } - } - None => { - // session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await.expect("Failed to send error"); - if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { - error!("Failed to send error response: {:?}", e); - } - Err(Box::new(Error { - etype: HTTPStatus(502), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: None, - })) - } - } - } - - async fn upstream_request_filter(&self, _session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { - // Track when we start upstream request - ctx.upstream_start_time = Some(Instant::now()); - - // Check if config has a Host header before setting default - let mut config_has_host = false; - if let Some(hostname) = ctx.hostname.as_ref() { - let path = _session.req_header().uri.path(); - if let Some(configured_headers) = self.get_header(hostname, path) { - for (key, _) in configured_headers.iter() { - if key.eq_ignore_ascii_case("Host") { - config_has_host = true; - break; - } - } - } - } - - // Only set default Host if config doesn't override it - if !config_has_host { - if let Some(hostname) = ctx.hostname.as_ref() { - upstream_request.insert_header("Host", hostname)?; - } - } - - if let Some(peer) = ctx.upstream_peer.as_ref() { - upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; - } - - // Apply configured headers from upstreams.yaml (will override default Host if present) - if let Some(hostname) = ctx.hostname.as_ref() { - let path = _session.req_header().uri.path(); - if let Some(configured_headers) = self.get_header(hostname, path) { - for (key, value) in configured_headers { - // insert_header will override existing headers with the same name - let key_clone = key.clone(); - let value_clone = value.clone(); - if let Err(e) = upstream_request.insert_header(key_clone, value_clone) { - debug!("Failed to insert header {}: {}", key, e); - } - } - } - } - - Ok(()) - } - - - async fn request_body_filter(&self, _session: &mut Session, body: &mut Option, end_of_stream: bool, ctx: &mut Self::CTX) -> Result<()> - where - Self::CTX: Send + Sync, - { - // Accumulate request body for content scanning - // Copy the body data but don't take it - Pingora will forward it if no malware - if let Some(body_bytes) = body { - info!("BODY CHUNK received: {} bytes, total so far: {}, end_of_stream: {}", body_bytes.len(), ctx.request_body.len() + body_bytes.len(), end_of_stream); - ctx.request_body.extend_from_slice(body_bytes); - } - - if end_of_stream && !ctx.request_body.is_empty() { - if let Some(scanner) = crate::content_scanning::get_global_content_scanner() { - // Get peer address for scanning - let peer_addr = if let Some(addr) = _session.client_addr().and_then(|a| a.as_inet()) { - std::net::SocketAddr::new(addr.ip(), addr.port()) - } else { - return Ok(()); // Can't scan without peer address - }; - - // Convert request header to Parts for should_scan check - let req_header = _session.req_header(); - let method = req_header.method.as_str(); - let uri = req_header.uri.to_string(); - let mut req_builder = hyper::http::Request::builder() - .method(method) - .uri(&uri); - - // Copy essential headers for content scanning (content-type, content-length) - if let Some(content_type) = req_header.headers.get("content-type") { - if let Ok(ct_str) = content_type.to_str() { - 
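
The Host-header precedence implemented in upstream_request_filter (a Host entry from the configured headers wins over the default derived from the incoming hostname, because insert_header overwrites) can be summarized by a hypothetical helper; effective_host is not part of the crate:

    // Configured Host beats the incoming hostname; otherwise fall back.
    fn effective_host(configured: &[(String, String)], incoming: &str) -> String {
        configured
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case("host"))
            .map(|(_, v)| v.clone())
            .unwrap_or_else(|| incoming.to_string())
    }

    // effective_host(&[("Host".into(), "internal.svc".into())], "example.com")
    // yields "internal.svc"; with no configured Host it yields "example.com".
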
req_builder = req_builder.header("content-type", ct_str); - } - } - if let Some(content_length) = req_header.headers.get("content-length") { - if let Ok(cl_str) = content_length.to_str() { - req_builder = req_builder.header("content-length", cl_str); - } - } - - let req = match req_builder.body(()) { - Ok(req) => req, - Err(_) => { - warn!("Failed to build request for content scanning, skipping scan"); - return Ok(()); - } - }; - let (req_parts, _) = req.into_parts(); - - // Check if we should scan this request - info!("Content scanner: checking if should scan - body size: {}, method: {}, content-type: {:?}", - ctx.request_body.len(), req_parts.method, req_parts.headers.get("content-type")); - let should_scan = scanner.should_scan(&req_parts, &ctx.request_body, peer_addr); - if should_scan { - info!("Content scanner: WILL SCAN request body (size: {} bytes)", ctx.request_body.len()); - - // Check if content-type is multipart and scan accordingly - let content_type = req_parts.headers - .get("content-type") - .and_then(|h| h.to_str().ok()); - - let scan_result = if let Some(ct) = content_type { - info!("Content-Type header: {}", ct); - if let Some(boundary) = crate::content_scanning::extract_multipart_boundary(ct) { - info!("Detected multipart content with boundary: '{}', scanning parts individually", boundary); - scanner.scan_multipart_content(&ctx.request_body, &boundary).await - } else { - info!("Not multipart or no boundary found, scanning as single blob"); - scanner.scan_content(&ctx.request_body).await - } - } else { - info!("No Content-Type header, scanning as single blob"); - scanner.scan_content(&ctx.request_body).await - }; - - match scan_result { - Ok(scan_result) => { - if scan_result.malware_detected { - info!("Malware detected in request from {}: {} {} - signature: {:?}", - peer_addr, method, uri, scan_result.signature); - - // Mark malware detected in context - ctx.malware_detected = true; - - // Send 403 response immediately to block the request - let json_response = serde_json::json!({ - "success": false, - "error": "Request blocked", - "reason": "malware_detected", - "message": "Malware detected in request" - }); - let json_body = Bytes::from(json_response.to_string()); - - let mut header = ResponseHeader::build(403, None)?; - header.insert_header("Content-Type", "application/json")?; - header.insert_header("X-Content-Scan-Result", "malware_detected")?; - - _session.set_keepalive(None); - _session.write_response_header(Box::new(header), false).await?; - _session.write_response_body(Some(json_body), true).await?; - - ctx.malware_response_sent = true; - - // Return error to abort the request - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: ErrorSourceInternal, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } else { - debug!("Content scan completed: no malware detected"); - } - } - Err(e) => { - warn!("Content scanning failed: {}", e); - // On scanning error, allow the request to proceed (fail open) - } - } - } else { - debug!("Content scanner: skipping scan (should_scan returned false)"); - } - } - } - - Ok(()) - } - - async fn response_filter(&self, session: &mut Session, _upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX) -> Result<()> { - // Calculate upstream response time - if let Some(upstream_start) = ctx.upstream_start_time { - ctx.upstream_time = Some(upstream_start.elapsed()); - } - - // _upstream_response.insert_header("X-Proxied-From", 
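
extract_multipart_boundary is referenced but not defined in this patch; a rough stand-in handling the common quoted and unquoted forms might look like the following (the real helper in crate::content_scanning may differ, e.g. in case handling):

    // Pull the boundary parameter out of a multipart Content-Type value.
    fn multipart_boundary(content_type: &str) -> Option<String> {
        let ct = content_type.trim();
        if !ct.to_ascii_lowercase().starts_with("multipart/") {
            return None;
        }
        ct.split(';')
            .map(str::trim)
            .find_map(|param| param.strip_prefix("boundary="))
            .map(|b| b.trim_matches('"').to_string())
    }

    // multipart_boundary("multipart/form-data; boundary=----abc")
    //     => Some("----abc")

Note that scanning then fails open: any scanner error lets the request proceed, and only a positive malware verdict aborts with 403.
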
"Fooooooooooooooo").unwrap(); - if ctx.extraparams.sticky_sessions { - let backend_id = ctx.backend_id.clone(); - if let Some(bid) = self.ump_byid.get(&backend_id) { - let _ = _upstream_response.insert_header("set-cookie", format!("backend_id={}; Path=/; Max-Age=600; HttpOnly; SameSite=Lax", bid.address)); - } - } - match ctx.hostname.as_ref() { - Some(host) => { - let path = session.req_header().uri.path(); - let host_header = host; - let split_header = host_header.split_once(':'); - - match split_header { - Some(sh) => { - let yoyo = self.get_header(sh.0, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); - } - } - } - None => { - let yoyo = self.get_header(host_header, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); - } - } - } - } - } - None => {} - } - session.set_keepalive(Some(300)); - Ok(()) - } - - async fn logging(&self, session: &mut Session, _e: Option<&pingora_core::Error>, ctx: &mut Self::CTX) { - let response_code = session.response_written().map_or(0, |resp| resp.status.as_u16()); - - // Skip logging if disabled for this endpoint - if ctx.disable_access_log { - return; - } - - debug!("{}, response code: {response_code}", self.request_summary(session, ctx)); - - // Log TLS fingerprint if available - if let Some(ref fingerprint) = ctx.tls_fingerprint { - debug!( - "Request completed - JA4: {}, JA4_Raw: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}, Response: {}", - fingerprint.ja4, - fingerprint.ja4_raw, - fingerprint.tls_version, - fingerprint.cipher_suite, - fingerprint.sni, - fingerprint.alpn, - response_code - ); - } - - let m = &crate::utils::metrics::MetricTypes { - method: session.req_header().method.to_string(), - code: session.response_written().map(|resp| resp.status.as_str().to_owned()).unwrap_or("0".to_string()), - latency: ctx.start_time.elapsed(), - version: session.req_header().version, - }; - crate::utils::metrics::calc_metrics(m); - - // Create access log - if let (Some(peer_addr), Some(local_addr)) = ( - session.client_addr().and_then(|addr| addr.as_inet()), - session.server_addr().and_then(|addr| addr.as_inet()) - ) { - let peer_socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); - let local_socket_addr = std::net::SocketAddr::new(local_addr.ip(), local_addr.port()); - - // Convert request headers to hyper::http::request::Parts - let mut request_builder = http::Request::builder() - .method(session.req_header().method.as_str()) - .uri(session.req_header().uri.to_string()) - .version(session.req_header().version); - - // Copy headers - for (name, value) in session.req_header().headers.iter() { - request_builder = request_builder.header(name, value); - } - - let hyper_request = request_builder.body(()).unwrap(); - let (req_parts, _) = hyper_request.into_parts(); - - // Convert request body to Bytes - let req_body_bytes = bytes::Bytes::from(ctx.request_body.clone()); - - // Generate JA4H fingerprint from HTTP request - let ja4h_fingerprint = crate::ja4_plus::Ja4hFingerprint::from_http_request( - session.req_header().method.as_str(), - &format!("{:?}", session.req_header().version), - &session.req_header().headers - ); - - // Try to get TLS fingerprint from context or retrieve it again - // Priority: 1) Context, 2) Retrieve from storage, 3) None - let tls_fp_for_log = if let Some(tls_fp) = ctx.tls_fingerprint.as_ref() { - debug!("TLS fingerprint found in context - JA4: {}, 
JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", - tls_fp.ja4, tls_fp.ja4_unsorted, tls_fp.sni, tls_fp.alpn); - Some(tls_fp.clone()) - } else if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - // Try to retrieve TLS fingerprint again if not in context - // Use fallback lookup to handle PROXY protocol address mismatches - let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); - if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) { - debug!("TLS fingerprint retrieved from storage - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", - fingerprint.ja4, fingerprint.ja4_unsorted, fingerprint.sni, fingerprint.alpn); - // Store in context for future use in this request - ctx.tls_fingerprint = Some(fingerprint.clone()); - Some(fingerprint) - } else { - debug!("No TLS fingerprint found in storage for peer: {} (this may be normal if ClientHello callback didn't fire or PROXY protocol is used)", std_addr); - None - } - } else { - debug!("No peer address available for TLS fingerprint retrieval"); - None - }; - - // Use HTTP JA4H fingerprint for tls_fingerprint parameter - // The TLS JA4 fingerprint will be passed separately via tls_ja4_unsorted - let tls_fingerprint_for_log = Some(ja4h_fingerprint.clone()); - - // Get TCP fingerprint data (if available) - let tcp_fingerprint_data = if let Some(collector) = crate::utils::tcp_fingerprint::get_global_tcp_fingerprint_collector() { - collector.lookup_fingerprint(peer_addr.ip(), peer_addr.port()) - } else { - None - }; - - // Get server certificate info (if available) - // Try hostname first, then SNI from TLS fingerprint - let server_cert_info_opt = { - let hostname_to_use = ctx.hostname.as_ref() - .or_else(|| tls_fp_for_log.as_ref().and_then(|fp| fp.sni.as_ref())); - - if let Some(hostname) = hostname_to_use { - // Try to get certificate path from certificate store - let cert_path = if let Ok(store) = crate::worker::certificate::get_certificate_store().try_read() { - if let Some(certs) = store.as_ref() { - certs.get_cert_path_for_hostname(hostname) - } else { - None - } - } else { - None - }; - - // If certificate path found, extract certificate info - if let Some(cert_path) = cert_path { - crate::utils::tls::extract_cert_info(&cert_path) - } else { - None - } - } else { - None - } - }; - - // Build upstream info - let upstream_info = ctx.upstream_peer.as_ref().map(|peer| { - crate::access_log::UpstreamInfo { - selected: peer.address.clone(), - method: "round_robin".to_string(), // TODO: Get actual method from config - reason: "healthy".to_string(), // TODO: Get actual reason - } - }); - - // Build performance info - let performance_info = crate::access_log::PerformanceInfo { - request_time_ms: Some(ctx.start_time.elapsed().as_millis() as u64), - upstream_time_ms: ctx.upstream_time.map(|d| d.as_millis() as u64), - }; - - // Build response data - let response_data = crate::access_log::ResponseData { - response_json: serde_json::json!({ - "status": response_code, - "status_text": session.response_written() - .and_then(|resp| resp.status.canonical_reason()) - .unwrap_or("Unknown"), - "content_type": session.response_written() - .and_then(|resp| resp.headers.get("content-type")) - .and_then(|h| h.to_str().ok()), - "content_length": session.response_written() - .and_then(|resp| resp.headers.get("content-length")) - .and_then(|h| h.to_str().ok()) - .and_then(|s| s.parse::().ok()) - .unwrap_or(0), - "body": "" // Response body not captured - }), - blocking_info: 
None, - waf_result: ctx.waf_result.clone(), - threat_data: ctx.threat_data.clone(), - }; - - // Extract SNI, ALPN, cipher, JA4, and JA4_unsorted from TLS fingerprint if available - // Use the same TLS fingerprint we retrieved above - let (tls_sni, tls_alpn, tls_cipher, tls_ja4, tls_ja4_unsorted) = if let Some(tls_fp) = tls_fp_for_log.as_ref() { - // Validate that JA4 values are not empty - let ja4 = if tls_fp.ja4.is_empty() { - warn!("TLS fingerprint found but JA4 is empty - this should not happen"); - None - } else { - Some(tls_fp.ja4.clone()) - }; - - let ja4_unsorted = if tls_fp.ja4_unsorted.is_empty() { - warn!("TLS fingerprint found but JA4_unsorted is empty - this should not happen"); - None - } else { - Some(tls_fp.ja4_unsorted.clone()) - }; - - debug!( - "TLS fingerprint found for logging - JA4: {:?}, JA4_unsorted: {:?}, SNI: {:?}, ALPN: {:?}, Cipher: {:?}", - ja4, ja4_unsorted, tls_fp.sni, tls_fp.alpn, tls_fp.cipher_suite - ); - - // Use SNI from fingerprint, fallback to hostname from context or Host header - let sni = tls_fp.sni.clone().or_else(|| { - ctx.hostname.clone().or_else(|| { - session.req_header().headers.get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - }) - }); - - ( - sni, - tls_fp.alpn.clone(), - tls_fp.cipher_suite.clone(), - ja4, - ja4_unsorted, - ) - } else { - debug!("No TLS fingerprint found for logging - peer: {:?} (JA4/JA4_unsorted will be null)", peer_addr); - // Fallback: try to extract SNI from Host header if available - let sni = ctx.hostname.clone().or_else(|| { - session.req_header().headers.get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - }); - (sni, None, None, None, None) - }; - - // Create access log with upstream and performance info - if let Err(e) = crate::access_log::HttpAccessLog::create_from_parts( - &req_parts, - &req_body_bytes, - peer_socket_addr, - local_socket_addr, - tls_fingerprint_for_log.as_ref(), - tcp_fingerprint_data.as_ref(), - server_cert_info_opt.as_ref(), - response_data, - ctx.waf_result.as_ref(), - ctx.threat_data.as_ref(), - upstream_info, - Some(performance_info), - tls_sni, - tls_alpn, - tls_cipher, - tls_ja4, - tls_ja4_unsorted, - ).await { - warn!("Failed to create access log: {}", e); - } - } - } -} - -impl LB {} - -fn return_header_host(session: &Session) -> Option { - if session.is_http2() { - match session.req_header().uri.host() { - Some(host) => Option::from(host.to_string()), - None => None, - } - } else { - match session.req_header().headers.get("host") { - Some(host) => { - let header_host = host.to_str().unwrap().splitn(2, ':').collect::>(); - Option::from(header_host[0].to_string()) - } - None => None, - } - } -} +use crate::utils::structs::{AppConfig, Extraparams, Headers, InnerMap, UpstreamsDashMap, UpstreamsIdMap}; +use crate::http_proxy::gethosts::GetHost; +use crate::waf::wirefilter::{evaluate_waf_for_pingora_request, WafAction}; +use crate::waf::actions::captcha::{validate_captcha_token, apply_captcha_challenge_with_token, generate_captcha_token}; +use arc_swap::ArcSwap; +use async_trait::async_trait; +use axum::body::Bytes; +use dashmap::DashMap; +use log::{debug, error, info, warn}; +use once_cell::sync::Lazy; +use pingora_http::{RequestHeader, ResponseHeader, StatusCode}; +use pingora_core::prelude::*; +use pingora_core::ErrorSource::{Upstream, Internal as ErrorSourceInternal}; +use pingora_core::{Error, ErrorType::HTTPStatus, RetryType, ImmutStr}; +use pingora_core::listeners::ALPN; +use 
pingora_core::prelude::HttpPeer; +use pingora_limits::rate::Rate; +use pingora_proxy::{ProxyHttp, Session}; +use serde_json; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::time::Duration; +use tokio::time::Instant; +use hyper::http; + +static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); +static WAF_RATE_LIMITERS: Lazy>> = Lazy::new(|| DashMap::new()); + +#[derive(Clone)] +pub struct LB { + pub ump_upst: Arc, + pub ump_full: Arc, + pub ump_byid: Arc, + pub arxignis_paths: Arc, AtomicUsize)>>, + pub headers: Arc, + pub config: Arc, + pub extraparams: Arc>, + pub tcp_fingerprint_collector: Option>, + pub certificates: Option>>>>, +} + +pub struct Context { + backend_id: String, + start_time: Instant, + upstream_start_time: Option, + hostname: Option, + upstream_peer: Option, + extraparams: arc_swap::Guard>, + tls_fingerprint: Option>, + request_body: Vec, + malware_detected: bool, + malware_response_sent: bool, + waf_result: Option, + threat_data: Option, + upstream_time: Option, + disable_access_log: bool, +} + +#[async_trait] +impl ProxyHttp for LB { + type CTX = Context; + fn new_ctx(&self) -> Self::CTX { + Context { + backend_id: String::new(), + start_time: Instant::now(), + upstream_start_time: None, + hostname: None, + upstream_peer: None, + extraparams: self.extraparams.load(), + tls_fingerprint: None, + request_body: Vec::new(), + malware_detected: false, + malware_response_sent: false, + waf_result: None, + threat_data: None, + upstream_time: None, + disable_access_log: false, + } + } + async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { + // Enable body buffering for content scanning + session.enable_retry_buffering(); + + let ep = _ctx.extraparams.clone(); + + // Userland access rules check (fallback when eBPF/XDP is not available) + // Check if IP is blocked by access rules + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let client_ip: std::net::IpAddr = peer_addr.ip().into(); + + // Check if IP is blocked + if crate::access_rules::is_ip_blocked_by_access_rules(client_ip) { + log::info!("Userland access rules: Blocked request from IP: {} (matched block rule)", client_ip); + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("X-Block-Reason", "access_rules").ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + return Ok(true); + } + } + + // Try to get TLS fingerprint if available + // Use fallback lookup to handle PROXY protocol address mismatches + if _ctx.tls_fingerprint.is_none() { + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); + if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) { + _ctx.tls_fingerprint = Some(fingerprint.clone()); + debug!( + "TLS Fingerprint retrieved for session - Peer: {}, JA4: {}, SNI: {:?}, ALPN: {:?}", + std_addr, + fingerprint.ja4, + fingerprint.sni, + fingerprint.alpn + ); + } else { + debug!("No TLS fingerprint found in storage for peer: {} (PROXY protocol may cause this)", std_addr); + } + } + } + + // Get threat intelligence data BEFORE WAF evaluation + // This ensures threat intelligence is available in access logs even when WAF blocks/challenges early + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + if _ctx.threat_data.is_none() { + match 
crate::threat::get_threat_intel(&peer_addr.ip().to_string()).await { + Ok(Some(threat_response)) => { + _ctx.threat_data = Some(threat_response); + debug!("Threat intelligence retrieved for IP: {}", peer_addr.ip()); + } + Ok(None) => { + debug!("No threat intelligence data for IP: {}", peer_addr.ip()); + } + Err(e) => { + debug!("Threat intelligence error for IP {}: {}", peer_addr.ip(), e); + } + } + } + } + + // Evaluate WAF rules + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); + match evaluate_waf_for_pingora_request(session.req_header(), b"", socket_addr).await { + Ok(Some(waf_result)) => { + debug!("WAF rule matched: rule={}, id={}, action={:?}", waf_result.rule_name, waf_result.rule_id, waf_result.action); + + // Store threat response from WAF result if available (WAF already fetched it) + if let Some(threat_resp) = waf_result.threat_response.clone() { + _ctx.threat_data = Some(threat_resp); + debug!("Threat intelligence retrieved from WAF evaluation for IP: {}", peer_addr.ip()); + } + + // Store WAF result in context for access logging + _ctx.waf_result = Some(waf_result.clone()); + + match waf_result.action { + WafAction::Block => { + info!("WAF blocked request: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); + header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + return Ok(true); + } + WafAction::Challenge => { + info!("WAF challenge required: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); + + // Check for captcha token in cookies or headers + let mut captcha_token: Option = None; + + // Check cookies for captcha_token + if let Some(cookies) = session.req_header().headers.get("cookie") { + if let Ok(cookie_str) = cookies.to_str() { + for cookie in cookie_str.split(';') { + let trimmed = cookie.trim(); + if let Some(value) = trimmed.strip_prefix("captcha_token=") { + captcha_token = Some(value.to_string()); + break; + } + } + } + } + + // Check X-Captcha-Token header if not found in cookies + if captcha_token.is_none() { + if let Some(token_header) = session.req_header().headers.get("x-captcha-token") { + if let Ok(token_str) = token_header.to_str() { + captcha_token = Some(token_str.to_string()); + } + } + } + + // Validate token if present + let token_valid = if let Some(token) = &captcha_token { + let user_agent = session.req_header().headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + match validate_captcha_token(token, &peer_addr.ip().to_string(), &user_agent).await { + Ok(valid) => { + if valid { + debug!("Captcha token validated successfully"); + } else { + debug!("Captcha token validation failed"); + } + valid + } + Err(e) => { + error!("Captcha token validation error: {}", e); + false + } + } + } else { + false + }; + + if !token_valid { + // Generate a new token (don't reuse invalid token) + let jwt_token = { + let user_agent = session.req_header().headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + match generate_captcha_token( + peer_addr.ip().to_string(), + user_agent, + None, // JA4 fingerprint not available here + ).await { + Ok(token) => 
token.token, + Err(e) => { + error!("Failed to generate captcha token: {}", e); + // Fallback to challenge without token + match apply_captcha_challenge_with_token("") { + Ok(html) => { + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), false).await?; + session.write_response_body(Some(Bytes::from(html)), true).await?; + return Ok(true); + } + Err(e) => { + error!("Failed to apply captcha challenge: {}", e); + // Block the request if captcha fails + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); + header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + return Ok(true); + } + } + } + } + }; + + // Return captcha challenge page + match apply_captcha_challenge_with_token(&jwt_token) { + Ok(html) => { + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); + header.insert_header("Set-Cookie", format!("captcha_token={}; Path=/; HttpOnly; SameSite=Lax", jwt_token)).ok(); + header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); + header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), false).await?; + session.write_response_body(Some(Bytes::from(html)), true).await?; + return Ok(true); + } + Err(e) => { + error!("Failed to apply captcha challenge: {}", e); + // Block the request if captcha fails + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); + header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); + session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + return Ok(true); + } + } + } else { + // Token is valid, allow request to continue + debug!("Captcha token validated, allowing request"); + } + } + WafAction::RateLimit => { + // Get rate limit config from waf_result + if let Some(rate_limit_config) = &waf_result.rate_limit_config { + let period_secs = rate_limit_config.period_secs(); + let requests_limit = rate_limit_config.requests_count(); + + // Get or create rate limiter for this rule + let rate_limiter = WAF_RATE_LIMITERS + .entry(waf_result.rule_id.clone()) + .or_insert_with(|| { + debug!("Creating new rate limiter for rule {}: {} requests per {} seconds", + waf_result.rule_id, requests_limit, period_secs); + Arc::new(Rate::new(Duration::from_secs(period_secs))) + }) + .clone(); + + // Use client IP as the rate key + let rate_key = peer_addr.ip().to_string(); + let curr_window_requests = rate_limiter.observe(&rate_key, 1); + + if curr_window_requests > requests_limit as isize { + info!("Rate limit exceeded: rule={}, id={}, ip={}, requests={}/{}", + waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); + + let body = serde_json::json!({ + "error": "Too Many Requests", + "message": format!("Rate limit exceeded: {} requests per {} seconds", requests_limit, period_secs), + "rule": &waf_result.rule_name, + "rule_id": &waf_result.rule_id + }).to_string(); + + let mut header = ResponseHeader::build(429, None).unwrap(); + header.insert_header("X-Rate-Limit-Limit", requests_limit.to_string()).ok(); + 
header.insert_header("X-Rate-Limit-Remaining", "0").ok(); + header.insert_header("X-Rate-Limit-Reset", period_secs.to_string()).ok(); + header.insert_header("X-WAF-Rule", &waf_result.rule_name).ok(); + header.insert_header("X-WAF-Rule-ID", &waf_result.rule_id).ok(); + header.insert_header("Content-Type", "application/json").ok(); + + session.set_keepalive(None); + session.write_response_header(Box::new(header), false).await?; + session.write_response_body(Some(Bytes::from(body)), true).await?; + return Ok(true); + } else { + debug!("Rate limit check passed: rule={}, id={}, ip={}, requests={}/{}", + waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); + } + } else { + warn!("Rate limit action triggered but no config found for rule {}", waf_result.rule_id); + } + } + WafAction::Allow => { + debug!("WAF allowed request: rule={}, id={}", waf_result.rule_name, waf_result.rule_id); + // Allow the request to continue + } + } + } + Ok(None) => { + // No WAF rules matched, allow request to continue + debug!("WAF: No rules matched for uri={}", session.req_header().uri); + } + Err(e) => { + error!("WAF evaluation error: {}", e); + // On error, allow request to continue (fail open) + } + } + } else { + debug!("WAF: No peer address available for request"); + } + + let hostname = return_header_host(&session); + _ctx.hostname = hostname; + + let mut backend_id = None; + + if ep.sticky_sessions { + if let Some(cookies) = session.req_header().headers.get("cookie") { + if let Ok(cookie_str) = cookies.to_str() { + for cookie in cookie_str.split(';') { + let trimmed = cookie.trim(); + if let Some(value) = trimmed.strip_prefix("backend_id=") { + backend_id = Some(value); + break; + } + } + } + } + } + + match _ctx.hostname.as_ref() { + None => return Ok(false), + Some(host) => { + // let optioninnermap = self.get_host(host.as_str(), host.as_str(), backend_id); + let optioninnermap = self.get_host(host.as_str(), session.req_header().uri.path(), backend_id); + match optioninnermap { + None => return Ok(false), + Some(ref innermap) => { + // Check for HTTPS redirect before rate limiting + if ep.https_proxy_enabled.unwrap_or(false) || innermap.https_proxy_enabled { + if let Some(stream) = session.stream() { + if stream.get_ssl().is_none() { + // HTTP request - redirect to HTTPS + let uri = session.req_header().uri.path_and_query().map_or("/", |pq| pq.as_str()); + let port = self.config.proxy_port_tls.unwrap_or(403); + let redirect_url = format!("https://{}:{}{}", host, port, uri); + let mut redirect_response = ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None)?; + redirect_response.insert_header("Location", redirect_url)?; + redirect_response.insert_header("Content-Length", "0")?; + session.set_keepalive(None); + session.write_response_header(Box::new(redirect_response), false).await?; + return Ok(true); + } + } + } + if let Some(rate) = innermap.rate_limit.or(ep.rate_limit) { + // let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip().to_string()).unwrap_or_else(|| host.to_string()); + let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip()); + let curr_window_requests = RATE_LIMITER.observe(&rate_key, 1); + if curr_window_requests > rate { + let mut header = ResponseHeader::build(429, None).unwrap(); + header.insert_header("X-Rate-Limit-Limit", rate.to_string()).unwrap(); + header.insert_header("X-Rate-Limit-Remaining", "0").unwrap(); + header.insert_header("X-Rate-Limit-Reset", "1").unwrap(); + 
session.set_keepalive(None); + session.write_response_header(Box::new(header), true).await?; + debug!("Rate limited: {:?}, {}", rate_key, rate); + return Ok(true); + } + } + } + } + _ctx.upstream_peer = optioninnermap.clone(); + // Set disable_access_log flag from upstream config + if let Some(ref innermap) = optioninnermap { + _ctx.disable_access_log = innermap.disable_access_log; + } + } + } + Ok(false) + } + async fn upstream_peer(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result> { + // Check if malware was detected, send JSON response and prevent forwarding + if ctx.malware_detected && !ctx.malware_response_sent { + // Check if response has already been written + if session.response_written().is_some() { + warn!("Response already written, cannot block malware request in upstream_peer"); + ctx.malware_response_sent = true; + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } + + info!("Blocking request due to malware detection"); + + // Build JSON response + let json_response = serde_json::json!({ + "success": false, + "error": "Request blocked", + "reason": "malware_detected", + "message": "Malware detected in request" + }); + let json_body = Bytes::from(json_response.to_string()); + + // Build response header + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("Content-Type", "application/json").ok(); + header.insert_header("X-Content-Scan-Result", "malware_detected").ok(); + + session.set_keepalive(None); + + // Try to write response, handle error if response already sent + match session.write_response_header(Box::new(header), false).await { + Ok(_) => { + match session.write_response_body(Some(json_body), true).await { + Ok(_) => { + ctx.malware_response_sent = true; + } + Err(e) => { + warn!("Failed to write response body for malware block in upstream_peer: {}", e); + ctx.malware_response_sent = true; + } + } + } + Err(e) => { + warn!("Failed to write response header for malware block in upstream_peer: {}", e); + ctx.malware_response_sent = true; + } + } + + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } + + // let host_name = return_header_host(&session); + match ctx.hostname.as_ref() { + Some(_hostname) => { + match ctx.upstream_peer.as_ref() { + // Some((address, port, ssl, is_h2, https_proxy_enabled)) => { + Some(innermap) => { + let mut peer = Box::new(HttpPeer::new((innermap.address.clone(), innermap.port.clone()), innermap.ssl_enabled, String::new())); + // if session.is_http2() { + if innermap.http2_enabled { + peer.options.alpn = ALPN::H2; + } + if innermap.ssl_enabled { + // Use upstream server address for SNI, not the incoming hostname + // This ensures the SNI matches what the upstream server expects + peer.sni = innermap.address.clone(); + peer.options.verify_cert = false; + peer.options.verify_hostname = false; + } + + ctx.backend_id = format!("{}:{}:{}", innermap.address.clone(), innermap.port.clone(), innermap.ssl_enabled); + Ok(peer) + } + None => { + if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { + error!("Failed to send error response: {:?}", e); + } + Err(Box::new(Error { + etype: HTTPStatus(502), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + 
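
For clarity, the malware-block responses in request_body_filter and upstream_peer share one JSON shape (the malware_response_sent flag prevents writing it twice). Reproduced as a small helper; malware_block_body is illustrative, not a crate function:

    use serde_json::json;

    // The 403 body used for malware blocks, matching the patch's json! call.
    fn malware_block_body() -> String {
        json!({
            "success": false,
            "error": "Request blocked",
            "reason": "malware_detected",
            "message": "Malware detected in request"
        })
        .to_string()
    }
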
context: Option::from(ImmutStr::Static("Upstream not found")), + })) + } + } + } + None => { + // session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await.expect("Failed to send error"); + if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { + error!("Failed to send error response: {:?}", e); + } + Err(Box::new(Error { + etype: HTTPStatus(502), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: None, + })) + } + } + } + + async fn upstream_request_filter(&self, _session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { + // Track when we start upstream request + ctx.upstream_start_time = Some(Instant::now()); + + // Check if config has a Host header before setting default + let mut config_has_host = false; + if let Some(hostname) = ctx.hostname.as_ref() { + let path = _session.req_header().uri.path(); + if let Some(configured_headers) = self.get_header(hostname, path) { + for (key, _) in configured_headers.iter() { + if key.eq_ignore_ascii_case("Host") { + config_has_host = true; + break; + } + } + } + } + + // Only set default Host if config doesn't override it + if !config_has_host { + if let Some(hostname) = ctx.hostname.as_ref() { + upstream_request.insert_header("Host", hostname)?; + } + } + + if let Some(peer) = ctx.upstream_peer.as_ref() { + upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; + } + + // Apply configured headers from upstreams.yaml (will override default Host if present) + if let Some(hostname) = ctx.hostname.as_ref() { + let path = _session.req_header().uri.path(); + if let Some(configured_headers) = self.get_header(hostname, path) { + for (key, value) in configured_headers { + // insert_header will override existing headers with the same name + let key_clone = key.clone(); + let value_clone = value.clone(); + if let Err(e) = upstream_request.insert_header(key_clone, value_clone) { + debug!("Failed to insert header {}: {}", key, e); + } + } + } + } + + Ok(()) + } + + + async fn request_body_filter(&self, _session: &mut Session, body: &mut Option, end_of_stream: bool, ctx: &mut Self::CTX) -> Result<()> + where + Self::CTX: Send + Sync, + { + // Accumulate request body for content scanning + // Copy the body data but don't take it - Pingora will forward it if no malware + if let Some(body_bytes) = body { + info!("BODY CHUNK received: {} bytes, total so far: {}, end_of_stream: {}", body_bytes.len(), ctx.request_body.len() + body_bytes.len(), end_of_stream); + ctx.request_body.extend_from_slice(body_bytes); + } + + if end_of_stream && !ctx.request_body.is_empty() { + if let Some(scanner) = crate::content_scanning::get_global_content_scanner() { + // Get peer address for scanning + let peer_addr = if let Some(addr) = _session.client_addr().and_then(|a| a.as_inet()) { + std::net::SocketAddr::new(addr.ip(), addr.port()) + } else { + return Ok(()); // Can't scan without peer address + }; + + // Convert request header to Parts for should_scan check + let req_header = _session.req_header(); + let method = req_header.method.as_str(); + let uri = req_header.uri.to_string(); + let mut req_builder = hyper::http::Request::builder() + .method(method) + .uri(&uri); + + // Copy essential headers for content scanning (content-type, content-length) + if let Some(content_type) = req_header.headers.get("content-type") { + if let Ok(ct_str) = content_type.to_str() { + req_builder = req_builder.header("content-type", ct_str); + } + } + 
if let Some(content_length) = req_header.headers.get("content-length") { + if let Ok(cl_str) = content_length.to_str() { + req_builder = req_builder.header("content-length", cl_str); + } + } + + let req = match req_builder.body(()) { + Ok(req) => req, + Err(_) => { + warn!("Failed to build request for content scanning, skipping scan"); + return Ok(()); + } + }; + let (req_parts, _) = req.into_parts(); + + // Check if we should scan this request + info!("Content scanner: checking if should scan - body size: {}, method: {}, content-type: {:?}", + ctx.request_body.len(), req_parts.method, req_parts.headers.get("content-type")); + let should_scan = scanner.should_scan(&req_parts, &ctx.request_body, peer_addr); + if should_scan { + info!("Content scanner: WILL SCAN request body (size: {} bytes)", ctx.request_body.len()); + + // Check if content-type is multipart and scan accordingly + let content_type = req_parts.headers + .get("content-type") + .and_then(|h| h.to_str().ok()); + + let scan_result = if let Some(ct) = content_type { + info!("Content-Type header: {}", ct); + if let Some(boundary) = crate::content_scanning::extract_multipart_boundary(ct) { + info!("Detected multipart content with boundary: '{}', scanning parts individually", boundary); + scanner.scan_multipart_content(&ctx.request_body, &boundary).await + } else { + info!("Not multipart or no boundary found, scanning as single blob"); + scanner.scan_content(&ctx.request_body).await + } + } else { + info!("No Content-Type header, scanning as single blob"); + scanner.scan_content(&ctx.request_body).await + }; + + match scan_result { + Ok(scan_result) => { + if scan_result.malware_detected { + info!("Malware detected in request from {}: {} {} - signature: {:?}", + peer_addr, method, uri, scan_result.signature); + + // Mark malware detected in context + ctx.malware_detected = true; + + // Send 403 response immediately to block the request + let json_response = serde_json::json!({ + "success": false, + "error": "Request blocked", + "reason": "malware_detected", + "message": "Malware detected in request" + }); + let json_body = Bytes::from(json_response.to_string()); + + let mut header = ResponseHeader::build(403, None)?; + header.insert_header("Content-Type", "application/json")?; + header.insert_header("X-Content-Scan-Result", "malware_detected")?; + + _session.set_keepalive(None); + _session.write_response_header(Box::new(header), false).await?; + _session.write_response_body(Some(json_body), true).await?; + + ctx.malware_response_sent = true; + + // Return error to abort the request + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: ErrorSourceInternal, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } else { + debug!("Content scan completed: no malware detected"); + } + } + Err(e) => { + warn!("Content scanning failed: {}", e); + // On scanning error, allow the request to proceed (fail open) + } + } + } else { + debug!("Content scanner: skipping scan (should_scan returned false)"); + } + } + } + + Ok(()) + } + + async fn response_filter(&self, session: &mut Session, _upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX) -> Result<()> { + // Calculate upstream response time + if let Some(upstream_start) = ctx.upstream_start_time { + ctx.upstream_time = Some(upstream_start.elapsed()); + } + + // _upstream_response.insert_header("X-Proxied-From", "Fooooooooooooooo").unwrap(); + if ctx.extraparams.sticky_sessions { + let backend_id = 
ctx.backend_id.clone(); + if let Some(bid) = self.ump_byid.get(&backend_id) { + let _ = _upstream_response.insert_header("set-cookie", format!("backend_id={}; Path=/; Max-Age=600; HttpOnly; SameSite=Lax", bid.address)); + } + } + match ctx.hostname.as_ref() { + Some(host) => { + let path = session.req_header().uri.path(); + let host_header = host; + let split_header = host_header.split_once(':'); + + match split_header { + Some(sh) => { + let yoyo = self.get_header(sh.0, path); + for k in yoyo.iter() { + for t in k.iter() { + _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + } + } + } + None => { + let yoyo = self.get_header(host_header, path); + for k in yoyo.iter() { + for t in k.iter() { + _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); + } + } + } + } + } + None => {} + } + session.set_keepalive(Some(300)); + Ok(()) + } + + async fn logging(&self, session: &mut Session, _e: Option<&pingora_core::Error>, ctx: &mut Self::CTX) { + let response_code = session.response_written().map_or(0, |resp| resp.status.as_u16()); + + // Skip logging if disabled for this endpoint + if ctx.disable_access_log { + return; + } + + debug!("{}, response code: {response_code}", self.request_summary(session, ctx)); + + // Log TLS fingerprint if available + if let Some(ref fingerprint) = ctx.tls_fingerprint { + debug!( + "Request completed - JA4: {}, JA4_Raw: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}, Response: {}", + fingerprint.ja4, + fingerprint.ja4_raw, + fingerprint.tls_version, + fingerprint.cipher_suite, + fingerprint.sni, + fingerprint.alpn, + response_code + ); + } + + let m = &crate::utils::metrics::MetricTypes { + method: session.req_header().method.to_string(), + code: session.response_written().map(|resp| resp.status.as_str().to_owned()).unwrap_or("0".to_string()), + latency: ctx.start_time.elapsed(), + version: session.req_header().version, + }; + crate::utils::metrics::calc_metrics(m); + + // Create access log + if let (Some(peer_addr), Some(local_addr)) = ( + session.client_addr().and_then(|addr| addr.as_inet()), + session.server_addr().and_then(|addr| addr.as_inet()) + ) { + let peer_socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); + let local_socket_addr = std::net::SocketAddr::new(local_addr.ip(), local_addr.port()); + + // Convert request headers to hyper::http::request::Parts + let mut request_builder = http::Request::builder() + .method(session.req_header().method.as_str()) + .uri(session.req_header().uri.to_string()) + .version(session.req_header().version); + + // Copy headers + for (name, value) in session.req_header().headers.iter() { + request_builder = request_builder.header(name, value); + } + + let hyper_request = request_builder.body(()).unwrap(); + let (req_parts, _) = hyper_request.into_parts(); + + // Convert request body to Bytes + let req_body_bytes = bytes::Bytes::from(ctx.request_body.clone()); + + // Generate JA4H fingerprint from HTTP request + let ja4h_fingerprint = crate::ja4_plus::Ja4hFingerprint::from_http_request( + session.req_header().method.as_str(), + &format!("{:?}", session.req_header().version), + &session.req_header().headers + ); + + // Try to get TLS fingerprint from context or retrieve it again + // Priority: 1) Context, 2) Retrieve from storage, 3) None + let tls_fp_for_log = if let Some(tls_fp) = ctx.tls_fingerprint.as_ref() { + debug!("TLS fingerprint found in context - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", + tls_fp.ja4, tls_fp.ja4_unsorted, tls_fp.sni, 
tls_fp.alpn);
+ Some(tls_fp.clone())
+ } else if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) {
+ // Try to retrieve TLS fingerprint again if not in context
+ // Use fallback lookup to handle PROXY protocol address mismatches
+ let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port());
+ if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) {
+ debug!("TLS fingerprint retrieved from storage - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}",
+ fingerprint.ja4, fingerprint.ja4_unsorted, fingerprint.sni, fingerprint.alpn);
+ // Store in context for future use in this request
+ ctx.tls_fingerprint = Some(fingerprint.clone());
+ Some(fingerprint)
+ } else {
+ debug!("No TLS fingerprint found in storage for peer: {} (this may be normal if ClientHello callback didn't fire or PROXY protocol is used)", std_addr);
+ None
+ }
+ } else {
+ debug!("No peer address available for TLS fingerprint retrieval");
+ None
+ };
+
+ // Use HTTP JA4H fingerprint for tls_fingerprint parameter
+ // The TLS JA4 fingerprint will be passed separately via tls_ja4_unsorted
+ let tls_fingerprint_for_log = Some(ja4h_fingerprint.clone());
+
+ // Get TCP fingerprint data (if available)
+ let tcp_fingerprint_data = if let Some(collector) = crate::utils::tcp_fingerprint::get_global_tcp_fingerprint_collector() {
+ collector.lookup_fingerprint(peer_addr.ip(), peer_addr.port())
+ } else {
+ None
+ };
+
+ // Get server certificate info (if available)
+ // Try hostname first, then SNI from TLS fingerprint
+ let server_cert_info_opt = {
+ let hostname_to_use = ctx.hostname.as_ref()
+ .or_else(|| tls_fp_for_log.as_ref().and_then(|fp| fp.sni.as_ref()));
+
+ if let Some(hostname) = hostname_to_use {
+ // Try to get certificate path from certificate store
+ let cert_path = if let Ok(store) = crate::worker::certificate::get_certificate_store().try_read() {
+ if let Some(certs) = store.as_ref() {
+ certs.get_cert_path_for_hostname(hostname)
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ // If certificate path found, extract certificate info
+ if let Some(cert_path) = cert_path {
+ crate::utils::tls::extract_cert_info(&cert_path)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ };
+
+ // Build upstream info
+ let upstream_info = ctx.upstream_peer.as_ref().map(|peer| {
+ crate::access_log::UpstreamInfo {
+ selected: peer.address.clone(),
+ method: "round_robin".to_string(), // TODO: Get actual method from config
+ reason: "healthy".to_string(), // TODO: Get actual reason
+ }
+ });
+
+ // Build performance info
+ let performance_info = crate::access_log::PerformanceInfo {
+ request_time_ms: Some(ctx.start_time.elapsed().as_millis() as u64),
+ upstream_time_ms: ctx.upstream_time.map(|d| d.as_millis() as u64),
+ };
+
+ // Build response data
+ let response_data = crate::access_log::ResponseData {
+ response_json: serde_json::json!({
+ "status": response_code,
+ "status_text": session.response_written()
+ .and_then(|resp| resp.status.canonical_reason())
+ .unwrap_or("Unknown"),
+ "content_type": session.response_written()
+ .and_then(|resp| resp.headers.get("content-type"))
+ .and_then(|h| h.to_str().ok()),
+ "content_length": session.response_written()
+ .and_then(|resp| resp.headers.get("content-length"))
+ .and_then(|h| h.to_str().ok())
+ .and_then(|s| s.parse::<u64>().ok())
+ .unwrap_or(0),
+ "body": "" // Response body not captured
+ }),
+ blocking_info: None,
+ waf_result: ctx.waf_result.clone(),
+ threat_data: ctx.threat_data.clone(),
+ };
+ 
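+ // Editor's note (illustrative): for a typical 200 response the
+ // response_json built above would look like:
+ // {"status": 200, "status_text": "OK", "content_type": "text/html",
+ //  "content_length": 1234, "body": ""}
+ // The body field is intentionally empty - response bodies are not captured.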
+ // Extract SNI, ALPN, cipher, JA4, and JA4_unsorted from TLS fingerprint if available
+ // Use the same TLS fingerprint we retrieved above
+ let (tls_sni, tls_alpn, tls_cipher, tls_ja4, tls_ja4_unsorted) = if let Some(tls_fp) = tls_fp_for_log.as_ref() {
+ // Validate that JA4 values are not empty
+ let ja4 = if tls_fp.ja4.is_empty() {
+ warn!("TLS fingerprint found but JA4 is empty - this should not happen");
+ None
+ } else {
+ Some(tls_fp.ja4.clone())
+ };
+
+ let ja4_unsorted = if tls_fp.ja4_unsorted.is_empty() {
+ warn!("TLS fingerprint found but JA4_unsorted is empty - this should not happen");
+ None
+ } else {
+ Some(tls_fp.ja4_unsorted.clone())
+ };
+
+ debug!(
+ "TLS fingerprint found for logging - JA4: {:?}, JA4_unsorted: {:?}, SNI: {:?}, ALPN: {:?}, Cipher: {:?}",
+ ja4, ja4_unsorted, tls_fp.sni, tls_fp.alpn, tls_fp.cipher_suite
+ );
+
+ // Use SNI from fingerprint, fallback to hostname from context or Host header
+ let sni = tls_fp.sni.clone().or_else(|| {
+ ctx.hostname.clone().or_else(|| {
+ session.req_header().headers.get("host")
+ .and_then(|h| h.to_str().ok())
+ .map(|h| h.split(':').next().unwrap_or(h).to_string())
+ })
+ });
+
+ (
+ sni,
+ tls_fp.alpn.clone(),
+ tls_fp.cipher_suite.clone(),
+ ja4,
+ ja4_unsorted,
+ )
+ } else {
+ debug!("No TLS fingerprint found for logging - peer: {:?} (JA4/JA4_unsorted will be null)", peer_addr);
+ // Fallback: try to extract SNI from Host header if available
+ let sni = ctx.hostname.clone().or_else(|| {
+ session.req_header().headers.get("host")
+ .and_then(|h| h.to_str().ok())
+ .map(|h| h.split(':').next().unwrap_or(h).to_string())
+ });
+ (sni, None, None, None, None)
+ };
+
+ // Create access log with upstream and performance info
+ if let Err(e) = crate::access_log::HttpAccessLog::create_from_parts(
+ &req_parts,
+ &req_body_bytes,
+ peer_socket_addr,
+ local_socket_addr,
+ tls_fingerprint_for_log.as_ref(),
+ tcp_fingerprint_data.as_ref(),
+ server_cert_info_opt.as_ref(),
+ response_data,
+ ctx.waf_result.as_ref(),
+ ctx.threat_data.as_ref(),
+ upstream_info,
+ Some(performance_info),
+ tls_sni,
+ tls_alpn,
+ tls_cipher,
+ tls_ja4,
+ tls_ja4_unsorted,
+ ).await {
+ warn!("Failed to create access log: {}", e);
+ }
+ }
+ }
+}
+
+impl LB {}
+
+fn return_header_host(session: &Session) -> Option<String> {
+ if session.is_http2() {
+ match session.req_header().uri.host() {
+ Some(host) => Option::from(host.to_string()),
+ None => None,
+ }
+ } else {
+ match session.req_header().headers.get("host") {
+ Some(host) => {
+ let header_host = host.to_str().unwrap().splitn(2, ':').collect::<Vec<&str>>();
+ Option::from(header_host[0].to_string())
+ }
+ None => None,
+ }
+ }
+}
diff --git a/src/http_proxy/start.rs b/src/http_proxy/start.rs
index bf20614..ba7f87f 100644
--- a/src/http_proxy/start.rs
+++ b/src/http_proxy/start.rs
@@ -1,285 +1,285 @@
-// use rustls::crypto::ring::default_provider;
-use crate::http_proxy::proxyhttp::LB;
-use crate::utils::structs::Extraparams;
-use crate::utils::tls;
-use crate::utils::tls::CertificateConfig;
-
-use arc_swap::ArcSwap;
-use ctrlc;
-use dashmap::DashMap;
-use log::{debug, info, warn};
-use pingora_core::listeners::tls::TlsSettings;
-use pingora_core::prelude::{background_service, Opt};
-use pingora_core::server::Server;
-use std::fs;
-use std::process;
-use std::sync::mpsc::{channel, Receiver, Sender};
-use std::sync::Arc;
-use std::thread;
-
-pub fn run() {
- run_with_config(None)
-}
-
-pub fn run_with_config(config: Option) {
- // default_provider().install_default().expect("Failed to install rustls crypto 
provider"); - let maincfg = if let Some(cfg) = config { - cfg.pingora.to_app_config() - } else { - // Fallback to old parsing method for backward compatibility - let parameters = Some(Opt::parse_args()).unwrap(); - let file = parameters.conf.clone().unwrap(); - crate::utils::parceyaml::parce_main_config(file.as_str()) - }; - - // Skip old proxy system if no proxy addresses are configured (using new config format) - if maincfg.proxy_address_http.is_empty() { - info!("Pingora proxy system disabled (no proxy_address_http configured)"); - info!( - "Using new HTTP server on: {}:{}", - maincfg.proxy_address_http, - maincfg.proxy_port_tls.unwrap_or(443) - ); - return; - } - - info!( - "Starting Pingora proxy system on HTTP: {}", - maincfg.proxy_address_http - ); - if let Some(ref tls_addr) = maincfg.proxy_address_tls { - info!("Pingora proxy TLS enabled on: {}", tls_addr); - } - - // Pass None to avoid pingora parsing the config file (we use our own parser above) - let mut server = Server::new(None).unwrap(); - - // Store proxy_protocol_enabled before moving maincfg - let proxy_protocol_enabled = maincfg.proxy_protocol_enabled; - - // Use proxy_protocol_enabled from config - if proxy_protocol_enabled { - info!("PROXY protocol support enabled - Pingora will parse headers before HTTP/TLS"); - info!("WARNING: All incoming connections must include PROXY protocol headers when enabled"); - info!("Direct connections without PROXY headers will fail. Ensure load balancer sends PROXY headers."); - info!("PROXY protocol will be parsed before TLS handshake for secure connections"); - // Enable PROXY protocol globally in Pingora - pingora_core::protocols::proxy_protocol::set_proxy_protocol_enabled(true); - info!("PROXY protocol enabled globally for all TCP and TLS listeners"); - } else { - info!("PROXY protocol support disabled - direct connections allowed"); - } - - server.bootstrap(); - - let uf_config = Arc::new(DashMap::new()); - let ff_config = Arc::new(DashMap::new()); - let im_config = Arc::new(DashMap::new()); - let hh_config = Arc::new(DashMap::new()); - let ap_config = Arc::new(DashMap::new()); - - let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams { - sticky_sessions: false, - https_proxy_enabled: None, - authentication: DashMap::new(), - rate_limit: None, - })); - - let cfg = Arc::new(maincfg); - - let certificates_arc: Arc>>> = - Arc::new(ArcSwap::from_pointee(None)); - - let lb = LB { - ump_upst: uf_config, - ump_full: ff_config, - ump_byid: im_config, - arxignis_paths: ap_config, - config: cfg.clone(), - headers: hh_config, - extraparams: ec_config, - tcp_fingerprint_collector: None, // TODO: Pass from main.rs if available - certificates: Some(certificates_arc.clone()), - }; - - let grade = cfg - .proxy_tls_grade - .clone() - .unwrap_or_else(|| "medium".to_string()); - info!("TLS grade set to: [ {} ]", grade); - - let bg_srvc = background_service("bgsrvc", lb.clone()); - let mut proxy = pingora_proxy::http_proxy_service(&server.configuration, lb.clone()); - let bind_address_http = cfg.proxy_address_http.clone(); - let bind_address_tls = cfg.proxy_address_tls.clone(); - - crate::utils::tools::check_priv(bind_address_http.as_str()); - - match bind_address_tls { - Some(bind_address_tls) => { - crate::utils::tools::check_priv(bind_address_tls.as_str()); - let (tx, rx): (Sender>, Receiver>) = - channel(); - let certs_path = cfg.proxy_certificates.clone().unwrap(); - - // Check if directory exists before watching - let certs_path_exists = fs::metadata(&certs_path).is_ok(); - let 
certs_path_clone = certs_path.clone(); - - if certs_path_exists { - // Start watcher thread - it will send initial configs - thread::spawn(move || { - if let Err(e) = crate::utils::tools::watch_folder(certs_path_clone, tx) { - warn!("Failed to watch certificate directory: {:?}", e); - } - }); - } else { - warn!("Certificate directory does not exist: {}. TLS will be disabled until certificates are added.", certs_path); - // Send empty configs so receiver doesn't block - if tx.send(vec![]).is_err() { - warn!("Failed to send initial certificate configs"); - } - } - - // Receive initial certificate configs - let certificate_configs = match rx.recv() { - Ok(configs) => configs, - Err(e) => { - warn!("Failed to receive certificate configs: {:?}. TLS will be disabled.", e); - vec![] - } - }; - - if let Some(first_set) = tls::Certificates::new( - &certificate_configs, - grade.as_str(), - cfg.default_certificate.as_ref(), - ) { - let first_set_arc: Arc = Arc::new(first_set); - certificates_arc.store(Arc::new(Some(first_set_arc.clone()) - as Option>)); - - // Set global certificates for SNI callback - tls::set_global_certificates(first_set_arc.clone()); - - let default_cert_path = first_set_arc.default_cert_path.clone(); - let default_key_path = first_set_arc.default_key_path.clone(); - - // Create TlsSettings with SNI callback for certificate selection - let tls_settings = match tls::create_tls_settings_with_sni( - &default_cert_path, - &default_key_path, - grade.as_str(), - Some(first_set_arc.clone()), - ) { - Ok(settings) => settings, - Err(e) => { - warn!( - "Failed to create TlsSettings with SNI callback: {}, falling back to default", - e - ); - let mut settings = TlsSettings::intermediate( - &default_cert_path, - &default_key_path, - ) - .expect("unable to load or parse cert/key"); - tls::set_tsl_grade(&mut settings, grade.as_str()); - tls::set_alpn_prefer_h2(&mut settings); - settings - } - }; - - // Register ClientHello callback to generate fingerprints - // Note: When PROXY protocol is enabled, ClientHello extraction may fail if the connection - // is reset before TLS handshake completes. The "Failed to peek at socket" warnings are - // expected in this case and are non-fatal - the TLS handshake will still proceed. 
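The TlsSettings construction above follows a simple try-primary-then-fallback pattern: attempt the SNI-aware constructor and, on error, build conservative defaults instead of aborting startup. A minimal self-contained sketch of that pattern; `Settings`, `build_primary`, and `build_fallback` are illustrative stand-ins, not the pingora API:

    // Sketch of the fallback construction used above (hypothetical names).
    #[derive(Debug)]
    struct Settings { sni_aware: bool }

    fn build_primary() -> Result<Settings, String> {
        // Pretend the SNI-aware path failed, as in the Err arm above.
        Err("SNI callback unavailable".to_string())
    }

    fn build_fallback() -> Settings {
        // Conservative defaults that always succeed.
        Settings { sni_aware: false }
    }

    fn main() {
        let settings = match build_primary() {
            Ok(s) => s,
            Err(e) => {
                eprintln!("primary settings failed: {e}, falling back");
                build_fallback()
            }
        };
        println!("{settings:?}");
    }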
- #[cfg(unix)] - { - use log::info; - use pingora_core::listeners::set_client_hello_callback; - use pingora_core::protocols::l4::socket::SocketAddr; - use pingora_core::protocols::tls::client_hello::ClientHello; - - set_client_hello_callback(Some( - |hello: &ClientHello, peer_addr: Option| { - let peer_str = peer_addr - .as_ref() - .and_then(|a| a.as_inet()) - .map(|inet| format!("{}:{}", inet.ip(), inet.port())) - .unwrap_or_else(|| "unknown".to_string()); - debug!( - "ClientHello callback invoked for peer: {}, SNI: {:?}, ALPN: {:?}, raw_len={}", - peer_str, - hello.sni, - hello.alpn, - hello.raw.len() - ); - // Generate fingerprint from ClientHello - if let Some(_fp) = - crate::utils::tls_client_hello::generate_fingerprint_from_client_hello( - hello, - peer_addr, - ) - { - debug!("Fingerprint generated successfully for peer: {}", peer_str); - } else { - // Log at debug level - failures are more common with PROXY protocol - // due to connection resets, but this is non-fatal - debug!("Failed to generate fingerprint for peer: {} (non-fatal, TLS handshake will continue)", peer_str); - } - }, - )); - if proxy_protocol_enabled { - info!("TLS ClientHello callback registered for fingerprint generation (PROXY protocol enabled - some extraction failures are expected and non-fatal)"); - } else { - info!("TLS ClientHello callback registered for fingerprint generation"); - } - } - - proxy.add_tls_with_settings(&bind_address_tls, None, tls_settings); - } else { - info!("TLS listener disabled: no certificates found in directory. TLS will be enabled when certificates are added."); - } - - let certs_for_watcher = certificates_arc.clone(); - let default_cert_for_watcher = cfg.default_certificate.clone(); - thread::spawn(move || { - while let Ok(new_configs) = rx.recv() { - let new_certs = - tls::Certificates::new(&new_configs, grade.as_str(), default_cert_for_watcher.as_ref()); - if let Some(new_certs) = new_certs { - certs_for_watcher.store(Arc::new(Some(Arc::new(new_certs)))); - } - } - }); - } - None => {} - } - - info!("Running HTTP listener on :{}", bind_address_http.as_str()); - proxy.add_tcp(bind_address_http.as_str()); - - server.add_service(proxy); - server.add_service(bg_srvc); - - thread::spawn(move || server.run_forever()); - - if let (Some(user), Some(group)) = (cfg.rungroup.clone(), cfg.runuser.clone()) { - crate::utils::tools::drop_priv( - user, - group, - cfg.proxy_address_http.clone(), - cfg.proxy_address_tls.clone(), - ); - } - - let (tx, rx) = channel(); - ctrlc::set_handler(move || { - let _ = tx.send(()); - }) - .expect("Error setting Ctrl-C handler"); - rx.recv() - .expect("Could not receive from channel."); - info!("Signal received ! 
Exiting..."); - process::exit(0); -} +// use rustls::crypto::ring::default_provider; +use crate::http_proxy::proxyhttp::LB; +use crate::utils::structs::Extraparams; +use crate::utils::tls; +use crate::utils::tls::CertificateConfig; + +use arc_swap::ArcSwap; +use ctrlc; +use dashmap::DashMap; +use log::{debug, info, warn}; +use pingora_core::listeners::tls::TlsSettings; +use pingora_core::prelude::{background_service, Opt}; +use pingora_core::server::Server; +use std::fs; +use std::process; +use std::sync::mpsc::{channel, Receiver, Sender}; +use std::sync::Arc; +use std::thread; + +pub fn run() { + run_with_config(None) +} + +pub fn run_with_config(config: Option) { + // default_provider().install_default().expect("Failed to install rustls crypto provider"); + let maincfg = if let Some(cfg) = config { + cfg.pingora.to_app_config() + } else { + // Fallback to old parsing method for backward compatibility + let parameters = Some(Opt::parse_args()).unwrap(); + let file = parameters.conf.clone().unwrap(); + crate::utils::parceyaml::parce_main_config(file.as_str()) + }; + + // Skip old proxy system if no proxy addresses are configured (using new config format) + if maincfg.proxy_address_http.is_empty() { + info!("Pingora proxy system disabled (no proxy_address_http configured)"); + info!( + "Using new HTTP server on: {}:{}", + maincfg.proxy_address_http, + maincfg.proxy_port_tls.unwrap_or(443) + ); + return; + } + + info!( + "Starting Pingora proxy system on HTTP: {}", + maincfg.proxy_address_http + ); + if let Some(ref tls_addr) = maincfg.proxy_address_tls { + info!("Pingora proxy TLS enabled on: {}", tls_addr); + } + + // Pass None to avoid pingora parsing the config file (we use our own parser above) + let mut server = Server::new(None).unwrap(); + + // Store proxy_protocol_enabled before moving maincfg + let proxy_protocol_enabled = maincfg.proxy_protocol_enabled; + + // Use proxy_protocol_enabled from config + if proxy_protocol_enabled { + info!("PROXY protocol support enabled - Pingora will parse headers before HTTP/TLS"); + info!("WARNING: All incoming connections must include PROXY protocol headers when enabled"); + info!("Direct connections without PROXY headers will fail. 
Ensure load balancer sends PROXY headers."); + info!("PROXY protocol will be parsed before TLS handshake for secure connections"); + // Enable PROXY protocol globally in Pingora + pingora_core::protocols::proxy_protocol::set_proxy_protocol_enabled(true); + info!("PROXY protocol enabled globally for all TCP and TLS listeners"); + } else { + info!("PROXY protocol support disabled - direct connections allowed"); + } + + server.bootstrap(); + + let uf_config = Arc::new(DashMap::new()); + let ff_config = Arc::new(DashMap::new()); + let im_config = Arc::new(DashMap::new()); + let hh_config = Arc::new(DashMap::new()); + let ap_config = Arc::new(DashMap::new()); + + let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams { + sticky_sessions: false, + https_proxy_enabled: None, + authentication: DashMap::new(), + rate_limit: None, + })); + + let cfg = Arc::new(maincfg); + + let certificates_arc: Arc>>> = + Arc::new(ArcSwap::from_pointee(None)); + + let lb = LB { + ump_upst: uf_config, + ump_full: ff_config, + ump_byid: im_config, + arxignis_paths: ap_config, + config: cfg.clone(), + headers: hh_config, + extraparams: ec_config, + tcp_fingerprint_collector: None, // TODO: Pass from main.rs if available + certificates: Some(certificates_arc.clone()), + }; + + let grade = cfg + .proxy_tls_grade + .clone() + .unwrap_or_else(|| "medium".to_string()); + info!("TLS grade set to: [ {} ]", grade); + + let bg_srvc = background_service("bgsrvc", lb.clone()); + let mut proxy = pingora_proxy::http_proxy_service(&server.configuration, lb.clone()); + let bind_address_http = cfg.proxy_address_http.clone(); + let bind_address_tls = cfg.proxy_address_tls.clone(); + + crate::utils::tools::check_priv(bind_address_http.as_str()); + + match bind_address_tls { + Some(bind_address_tls) => { + crate::utils::tools::check_priv(bind_address_tls.as_str()); + let (tx, rx): (Sender>, Receiver>) = + channel(); + let certs_path = cfg.proxy_certificates.clone().unwrap(); + + // Check if directory exists before watching + let certs_path_exists = fs::metadata(&certs_path).is_ok(); + let certs_path_clone = certs_path.clone(); + + if certs_path_exists { + // Start watcher thread - it will send initial configs + thread::spawn(move || { + if let Err(e) = crate::utils::tools::watch_folder(certs_path_clone, tx) { + warn!("Failed to watch certificate directory: {:?}", e); + } + }); + } else { + warn!("Certificate directory does not exist: {}. TLS will be disabled until certificates are added.", certs_path); + // Send empty configs so receiver doesn't block + if tx.send(vec![]).is_err() { + warn!("Failed to send initial certificate configs"); + } + } + + // Receive initial certificate configs + let certificate_configs = match rx.recv() { + Ok(configs) => configs, + Err(e) => { + warn!("Failed to receive certificate configs: {:?}. 
TLS will be disabled.", e); + vec![] + } + }; + + if let Some(first_set) = tls::Certificates::new( + &certificate_configs, + grade.as_str(), + cfg.default_certificate.as_ref(), + ) { + let first_set_arc: Arc = Arc::new(first_set); + certificates_arc.store(Arc::new(Some(first_set_arc.clone()) + as Option>)); + + // Set global certificates for SNI callback + tls::set_global_certificates(first_set_arc.clone()); + + let default_cert_path = first_set_arc.default_cert_path.clone(); + let default_key_path = first_set_arc.default_key_path.clone(); + + // Create TlsSettings with SNI callback for certificate selection + let tls_settings = match tls::create_tls_settings_with_sni( + &default_cert_path, + &default_key_path, + grade.as_str(), + Some(first_set_arc.clone()), + ) { + Ok(settings) => settings, + Err(e) => { + warn!( + "Failed to create TlsSettings with SNI callback: {}, falling back to default", + e + ); + let mut settings = TlsSettings::intermediate( + &default_cert_path, + &default_key_path, + ) + .expect("unable to load or parse cert/key"); + tls::set_tsl_grade(&mut settings, grade.as_str()); + tls::set_alpn_prefer_h2(&mut settings); + settings + } + }; + + // Register ClientHello callback to generate fingerprints + // Note: When PROXY protocol is enabled, ClientHello extraction may fail if the connection + // is reset before TLS handshake completes. The "Failed to peek at socket" warnings are + // expected in this case and are non-fatal - the TLS handshake will still proceed. + #[cfg(unix)] + { + use log::info; + use pingora_core::listeners::set_client_hello_callback; + use pingora_core::protocols::l4::socket::SocketAddr; + use pingora_core::protocols::tls::client_hello::ClientHello; + + set_client_hello_callback(Some( + |hello: &ClientHello, peer_addr: Option| { + let peer_str = peer_addr + .as_ref() + .and_then(|a| a.as_inet()) + .map(|inet| format!("{}:{}", inet.ip(), inet.port())) + .unwrap_or_else(|| "unknown".to_string()); + debug!( + "ClientHello callback invoked for peer: {}, SNI: {:?}, ALPN: {:?}, raw_len={}", + peer_str, + hello.sni, + hello.alpn, + hello.raw.len() + ); + // Generate fingerprint from ClientHello + if let Some(_fp) = + crate::utils::tls_client_hello::generate_fingerprint_from_client_hello( + hello, + peer_addr, + ) + { + debug!("Fingerprint generated successfully for peer: {}", peer_str); + } else { + // Log at debug level - failures are more common with PROXY protocol + // due to connection resets, but this is non-fatal + debug!("Failed to generate fingerprint for peer: {} (non-fatal, TLS handshake will continue)", peer_str); + } + }, + )); + if proxy_protocol_enabled { + info!("TLS ClientHello callback registered for fingerprint generation (PROXY protocol enabled - some extraction failures are expected and non-fatal)"); + } else { + info!("TLS ClientHello callback registered for fingerprint generation"); + } + } + + proxy.add_tls_with_settings(&bind_address_tls, None, tls_settings); + } else { + info!("TLS listener disabled: no certificates found in directory. 
TLS will be enabled when certificates are added."); + } + + let certs_for_watcher = certificates_arc.clone(); + let default_cert_for_watcher = cfg.default_certificate.clone(); + thread::spawn(move || { + while let Ok(new_configs) = rx.recv() { + let new_certs = + tls::Certificates::new(&new_configs, grade.as_str(), default_cert_for_watcher.as_ref()); + if let Some(new_certs) = new_certs { + certs_for_watcher.store(Arc::new(Some(Arc::new(new_certs)))); + } + } + }); + } + None => {} + } + + info!("Running HTTP listener on :{}", bind_address_http.as_str()); + proxy.add_tcp(bind_address_http.as_str()); + + server.add_service(proxy); + server.add_service(bg_srvc); + + thread::spawn(move || server.run_forever()); + + if let (Some(user), Some(group)) = (cfg.rungroup.clone(), cfg.runuser.clone()) { + crate::utils::tools::drop_priv( + user, + group, + cfg.proxy_address_http.clone(), + cfg.proxy_address_tls.clone(), + ); + } + + let (tx, rx) = channel(); + ctrlc::set_handler(move || { + let _ = tx.send(()); + }) + .expect("Error setting Ctrl-C handler"); + rx.recv() + .expect("Could not receive from channel."); + info!("Signal received ! Exiting..."); + process::exit(0); +} diff --git a/src/http_proxy/webserver.rs b/src/http_proxy/webserver.rs index a9976de..53ccf1b 100644 --- a/src/http_proxy/webserver.rs +++ b/src/http_proxy/webserver.rs @@ -1,113 +1,113 @@ -use crate::utils::discovery::APIUpstreamProvider; -use crate::utils::structs::Configuration; -use axum::body::Body; -use axum::extract::{Query, State}; -use axum::http::{Response, StatusCode}; -use axum::response::IntoResponse; -use axum::routing::{get, post}; -use axum::Router; -use axum_server::tls_openssl::OpenSSLConfig; -use futures::channel::mpsc::Sender; -use futures::SinkExt; -use log::info; -use prometheus::{gather, Encoder, TextEncoder}; -use std::collections::HashMap; -use std::net::SocketAddr; -use tokio::net::TcpListener; -use tower_http::services::ServeDir; - -#[derive(Clone)] -struct AppState { - master_key: String, - config_sender: Sender, - config_api_enabled: bool, -} - -#[allow(unused_mut)] -pub async fn run_server(config: &APIUpstreamProvider, mut to_return: Sender) { - let app_state = AppState { - master_key: config.masterkey.clone(), - config_sender: to_return.clone(), - config_api_enabled: config.config_api_enabled.clone(), - }; - - let app = Router::new() - // .route("/{*wildcard}", get(senderror)) - // .route("/{*wildcard}", post(senderror)) - // .route("/{*wildcard}", put(senderror)) - // .route("/{*wildcard}", head(senderror)) - // .route("/{*wildcard}", delete(senderror)) - // .nest_service("/static", static_files) - .route("/conf", post(conf)) - .route("/metrics", get(metrics)) - .with_state(app_state); - - if let Some(value) = &config.tls_address { - let cf = OpenSSLConfig::from_pem_file(config.tls_certificate.clone().unwrap(), config.tls_key_file.clone().unwrap()).unwrap(); - let addr: SocketAddr = value.parse().expect("Unable to parse socket address"); - let tls_app = app.clone(); - tokio::spawn(async move { - if let Err(e) = axum_server::bind_openssl(addr, cf).serve(tls_app.into_make_service()).await { - eprintln!("TLS server failed: {}", e); - } - }); - info!("Starting the TLS API server on: {}", value); - } - - if let (Some(address), Some(folder)) = (&config.file_server_address, &config.file_server_folder) { - let static_files = ServeDir::new(folder); - let static_serve: Router = Router::new().fallback_service(static_files); - let static_listen = TcpListener::bind(address).await.unwrap(); - let _ = 
tokio::spawn(async move { axum::serve(static_listen, static_serve).await.unwrap() }); - } - - let listener = TcpListener::bind(config.address.clone()).await.unwrap(); - info!("Starting the API server on: {}", config.address); - axum::serve(listener, app).await.unwrap(); -} - -async fn conf(State(mut st): State, Query(params): Query>, content: String) -> impl IntoResponse { - if !st.config_api_enabled { - return Response::builder() - .status(StatusCode::FORBIDDEN) - .body(Body::from("Config remote API is disabled !\n")) - .unwrap(); - } - - if let Some(s) = params.get("key") { - if s.to_owned() == st.master_key { - if let Some(serverlist) = crate::utils::parceyaml::load_configuration(content.as_str(), "content").await { - st.config_sender.send(serverlist).await.unwrap(); - return Response::builder().status(StatusCode::OK).body(Body::from("Config, conf file, updated !\n")).unwrap(); - } else { - return Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("Failed to parse config!\n")).unwrap(); - }; - } - } - Response::builder().status(StatusCode::FORBIDDEN).body(Body::from("Access Denied !\n")).unwrap() -} - -async fn metrics() -> impl IntoResponse { - let metric_families = gather(); - let encoder = TextEncoder::new(); - - let mut buffer = Vec::new(); - if let Err(e) = encoder.encode(&metric_families, &mut buffer) { - // encoding error fallback - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(format!("Failed to encode metrics: {}", e))) - .unwrap(); - } - - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", encoder.format_type()) - .body(Body::from(buffer)) - .unwrap() -} - -// #[allow(dead_code)] -// async fn senderror() -> impl IntoResponse { -// Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("No live upstream found!\n")).unwrap() -// } +use crate::utils::discovery::APIUpstreamProvider; +use crate::utils::structs::Configuration; +use axum::body::Body; +use axum::extract::{Query, State}; +use axum::http::{Response, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::Router; +use axum_server::tls_openssl::OpenSSLConfig; +use futures::channel::mpsc::Sender; +use futures::SinkExt; +use log::info; +use prometheus::{gather, Encoder, TextEncoder}; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tower_http::services::ServeDir; + +#[derive(Clone)] +struct AppState { + master_key: String, + config_sender: Sender, + config_api_enabled: bool, +} + +#[allow(unused_mut)] +pub async fn run_server(config: &APIUpstreamProvider, mut to_return: Sender) { + let app_state = AppState { + master_key: config.masterkey.clone(), + config_sender: to_return.clone(), + config_api_enabled: config.config_api_enabled.clone(), + }; + + let app = Router::new() + // .route("/{*wildcard}", get(senderror)) + // .route("/{*wildcard}", post(senderror)) + // .route("/{*wildcard}", put(senderror)) + // .route("/{*wildcard}", head(senderror)) + // .route("/{*wildcard}", delete(senderror)) + // .nest_service("/static", static_files) + .route("/conf", post(conf)) + .route("/metrics", get(metrics)) + .with_state(app_state); + + if let Some(value) = &config.tls_address { + let cf = OpenSSLConfig::from_pem_file(config.tls_certificate.clone().unwrap(), config.tls_key_file.clone().unwrap()).unwrap(); + let addr: SocketAddr = value.parse().expect("Unable to parse socket address"); + let tls_app = app.clone(); + tokio::spawn(async move { + if let Err(e) = 
axum_server::bind_openssl(addr, cf).serve(tls_app.into_make_service()).await { + eprintln!("TLS server failed: {}", e); + } + }); + info!("Starting the TLS API server on: {}", value); + } + + if let (Some(address), Some(folder)) = (&config.file_server_address, &config.file_server_folder) { + let static_files = ServeDir::new(folder); + let static_serve: Router = Router::new().fallback_service(static_files); + let static_listen = TcpListener::bind(address).await.unwrap(); + let _ = tokio::spawn(async move { axum::serve(static_listen, static_serve).await.unwrap() }); + } + + let listener = TcpListener::bind(config.address.clone()).await.unwrap(); + info!("Starting the API server on: {}", config.address); + axum::serve(listener, app).await.unwrap(); +} + +async fn conf(State(mut st): State, Query(params): Query>, content: String) -> impl IntoResponse { + if !st.config_api_enabled { + return Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::from("Config remote API is disabled !\n")) + .unwrap(); + } + + if let Some(s) = params.get("key") { + if s.to_owned() == st.master_key { + if let Some(serverlist) = crate::utils::parceyaml::load_configuration(content.as_str(), "content").await { + st.config_sender.send(serverlist).await.unwrap(); + return Response::builder().status(StatusCode::OK).body(Body::from("Config, conf file, updated !\n")).unwrap(); + } else { + return Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("Failed to parse config!\n")).unwrap(); + }; + } + } + Response::builder().status(StatusCode::FORBIDDEN).body(Body::from("Access Denied !\n")).unwrap() +} + +async fn metrics() -> impl IntoResponse { + let metric_families = gather(); + let encoder = TextEncoder::new(); + + let mut buffer = Vec::new(); + if let Err(e) = encoder.encode(&metric_families, &mut buffer) { + // encoding error fallback + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(format!("Failed to encode metrics: {}", e))) + .unwrap(); + } + + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", encoder.format_type()) + .body(Body::from(buffer)) + .unwrap() +} + +// #[allow(dead_code)] +// async fn senderror() -> impl IntoResponse { +// Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("No live upstream found!\n")).unwrap() +// } diff --git a/src/ja4_plus.rs b/src/ja4_plus.rs index 6c1f241..d5c4889 100644 --- a/src/ja4_plus.rs +++ b/src/ja4_plus.rs @@ -1,723 +1,723 @@ -use sha2::{Digest, Sha256}; -use hyper::HeaderMap; - -/// JA4T: TCP Fingerprint from TCP options -/// Official Format: {window_size}_{tcp_options}_{mss}_{window_scale} -/// Example: "65535_2-4-8-1-3_1460_7" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4t/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4tFingerprint { - pub fingerprint: String, - pub window_size: u16, - pub ttl: u16, - pub mss: u16, - pub window_scale: u8, - pub options: Vec, -} - -impl Ja4tFingerprint { - /// Generate JA4T fingerprint from TCP parameters - /// TCP options are represented by their kind numbers: - /// 0 = EOL, 1 = NOP, 2 = MSS, 3 = Window Scale, 4 = SACK Permitted, - /// 5 = SACK, 8 = Timestamps, etc. 
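The JA4T format documented above can be worked by hand. A minimal runnable sketch that assembles the JA4T string from values resembling a common Linux SYN (the window size, option order, MSS, and scale are illustrative, not captured traffic):

    // Worked example of the JA4T string format documented above.
    fn main() {
        let window_size: u16 = 64240;
        // MSS, SACK-permitted, Timestamps, NOP, Window Scale (kind numbers)
        let option_kinds = [2u8, 4, 8, 1, 3];
        let mss: u16 = 1460;
        let window_scale: u8 = 7;

        let options = option_kinds
            .iter()
            .map(|k| k.to_string())
            .collect::<Vec<String>>()
            .join("-");

        let ja4t = format!("{}_{}_{}_{}", window_size, options, mss, window_scale);
        assert_eq!(ja4t, "64240_2-4-8-1-3_1460_7");
        println!("{ja4t}");
    }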
- pub fn from_tcp_data( - window_size: u16, - ttl: u16, - mss: u16, - window_scale: u8, - options: &[u8], - ) -> Self { - // Extract TCP option kinds - let mut option_kinds = Vec::new(); - let mut i = 0; - - while i < options.len() { - let kind = options[i]; - - match kind { - 0 => break, // EOL - 1 => { - // NOP - single byte - option_kinds.push(kind); - i += 1; - } - _ => { - // Options with length - if i + 1 < options.len() { - let len = options[i + 1] as usize; - option_kinds.push(kind); - i += len.max(2); - } else { - break; - } - } - } - } - - // Build fingerprint: window_size_options_mss_window_scale - // Official format from Zeek implementation - let options_str = option_kinds - .iter() - .map(|k| k.to_string()) - .collect::>() - .join("-"); - - let fingerprint = format!("{}_{}_{}_{}", - window_size, - if options_str.is_empty() { "0" } else { &options_str }, - mss, - window_scale - ); - - Self { - fingerprint, - window_size, - ttl, - mss, - window_scale, - options: option_kinds, - } - } - - /// Get the JA4T hash (first 12 characters of SHA-256) - pub fn hash(&self) -> String { - let digest = Sha256::digest(self.fingerprint.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - } -} - -/// JA4H: HTTP Header Fingerprint -/// Official Format: {method}{version}{cookie}{referer}{header_count}{language}_{headers_hash}_{cookie_names_hash}_{cookie_values_hash} -/// Example: "ge11cr15enus_a1b2c3d4e5f6_123456789abc_def012345678" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4h/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4hFingerprint { - pub fingerprint: String, - pub method: String, - pub version: String, - pub has_cookie: bool, - pub has_referer: bool, - pub header_count: usize, - pub language: String, -} - -impl Ja4hFingerprint { - /// Generate JA4H fingerprint from HTTP request - pub fn from_http_request( - method: &str, - version: &str, - headers: &HeaderMap, - ) -> Self { - // Method: first 2 letters, lowercase (ge=GET, po=POST, etc.) 
- let method_str = method.to_lowercase(); - let method_code = if method_str.len() >= 2 { - &method_str[..2] - } else { - "un" // unknown - }; - - // Version: 10, 11, 20, 30 - let version_code = match version { - "HTTP/0.9" => "09", - "HTTP/1.0" => "10", - "HTTP/1.1" => "11", - "HTTP/2.0" | "HTTP/2" => "20", - "HTTP/3.0" | "HTTP/3" => "30", - _ => "00", - }; - - // Check if Cookie and Referer headers exist - let has_cookie = headers.contains_key("cookie"); - let cookie_flag = if has_cookie { "c" } else { "n" }; - - let has_referer = headers.contains_key("referer"); - let referer_flag = if has_referer { "r" } else { "n" }; - - // Extract language from Accept-Language header - let language = if let Some(lang_header) = headers.get("accept-language") { - if let Ok(lang_str) = lang_header.to_str() { - // Take first language, remove hyphens, lowercase, pad to 4 chars - let primary_lang = lang_str.split(',').next().unwrap_or(""); - let clean_lang = primary_lang - .split(';') - .next() - .unwrap_or("") - .replace('-', "") - .to_lowercase(); - - let mut lang_code = clean_lang.chars().take(4).collect::(); - while lang_code.len() < 4 { - lang_code.push('0'); - } - lang_code - } else { - "0000".to_string() - } - } else { - "0000".to_string() - }; - - // Collect header names (excluding Cookie and Referer) - let mut header_names: Vec = headers - .iter() - .filter_map(|(name, _)| { - let name_str = name.as_str().to_lowercase(); - if name_str == "cookie" || name_str == "referer" { - None - } else { - Some(name_str) - } - }) - .collect(); - - header_names.sort(); - let header_count = header_names.len().min(99); - let header_count_str = format!("{:02}", header_count); - - // Create header hash - let header_string = header_names.join(","); - let header_hash = if header_string.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(header_string.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - // Parse cookies if they exist - let (cookie_names_hash, cookie_values_hash) = if let Some(cookie_value) = headers.get("cookie") { - if let Ok(cookie_str) = cookie_value.to_str() { - // Parse cookie pairs: name=value; name2=value2 - let mut cookie_names: Vec = Vec::new(); - let mut cookie_values: Vec = Vec::new(); - - for part in cookie_str.split(';') { - let trimmed = part.trim(); - if let Some((name, _value)) = trimmed.split_once('=') { - cookie_names.push(name.trim().to_string()); - cookie_values.push(trimmed.to_string()); // Full "name=value" - } - } - - // Sort separately - cookie_names.sort(); - cookie_values.sort(); - - // Hash separately - let names_hash = if cookie_names.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(cookie_names.join(",").as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let values_hash = if cookie_values.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(cookie_values.join(",").as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - (names_hash, values_hash) - } else { - ("000000000000".to_string(), "000000000000".to_string()) - } - } else { - ("000000000000".to_string(), "000000000000".to_string()) - }; - - // Build fingerprint: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values} - let fingerprint = format!( - "{}{}{}{}{}{}_{}_{}_{}", - method_code, - version_code, - cookie_flag, - referer_flag, - header_count_str, - language, - header_hash, - cookie_names_hash, - 
cookie_values_hash - ); - - Self { - fingerprint, - method: method.to_string(), - version: version.to_string(), - has_cookie, - has_referer, - header_count, - language, - } - } -} - -/// JA4L: Latency Fingerprint -/// Official Format: {rtt_microseconds}_{ttl} -/// Measures round-trip time between SYN and SYNACK packets -/// Example: "12500_64" (12.5ms RTT, TTL 64) -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4l/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4lMeasurement { - pub syn_time: Option, // Microseconds - pub synack_time: Option, // Microseconds - pub ack_time: Option, // Microseconds - pub ttl_client: Option, - pub ttl_server: Option, -} - -impl Ja4lMeasurement { - pub fn new() -> Self { - Self { - syn_time: None, - synack_time: None, - ack_time: None, - ttl_client: None, - ttl_server: None, - } - } - - /// Record SYN packet timestamp - pub fn set_syn(&mut self, timestamp_us: u64, ttl: u8) { - self.syn_time = Some(timestamp_us); - self.ttl_client = Some(ttl); - } - - /// Record SYNACK packet timestamp - pub fn set_synack(&mut self, timestamp_us: u64, ttl: u8) { - self.synack_time = Some(timestamp_us); - self.ttl_server = Some(ttl); - } - - /// Record ACK packet timestamp - pub fn set_ack(&mut self, timestamp_us: u64) { - self.ack_time = Some(timestamp_us); - } - - /// Generate JA4L client fingerprint - /// Format: {client_rtt_us}_{client_ttl} - /// RTT = (ACK - SYNACK) / 2 - pub fn fingerprint_client(&self) -> Option { - let synack = self.synack_time?; - let ack = self.ack_time?; - let ttl = self.ttl_client?; - - // Calculate client-side RTT (half of ACK-SYNACK time) - let rtt_us = (ack.saturating_sub(synack)) / 2; - - Some(format!("{}_{}", rtt_us, ttl)) - } - - /// Generate JA4L server fingerprint - /// Format: {server_rtt_us}_{server_ttl} - /// RTT = (SYNACK - SYN) / 2 - pub fn fingerprint_server(&self) -> Option { - let syn = self.syn_time?; - let synack = self.synack_time?; - let ttl = self.ttl_server?; - - // Calculate server-side RTT (half of SYNACK-SYN time) - let rtt_us = (synack.saturating_sub(syn)) / 2; - - Some(format!("{}_{}", rtt_us, ttl)) - } - - /// Legacy format for compatibility (if needed) - /// Returns both client and server measurements - pub fn fingerprint_combined(&self) -> Option { - let client = self.fingerprint_client()?; - let server = self.fingerprint_server()?; - - Some(format!("c:{},s:{}", client, server)) - } -} - -impl Default for Ja4lMeasurement { - fn default() -> Self { - Self::new() - } -} - -/// JA4S: TLS Server Response Fingerprint -/// Official Format: {proto}{version}{ext_count}{alpn}_{cipher}_{extensions_hash} -/// Example: "t130200_1301_a56c5b993250" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4s/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4sFingerprint { - pub fingerprint: String, - pub proto: String, - pub version: String, - pub cipher: u16, - pub extensions: Vec, - pub alpn: Option, -} - -impl Ja4sFingerprint { - /// Generate JA4S fingerprint from TLS ServerHello - /// - /// # Arguments - /// * `is_quic` - true if QUIC, false if TCP TLS - /// * `version` - TLS version (0x0304 for TLS 1.3, etc.) 
- /// * `cipher` - Cipher suite selected by server - /// * `extensions` - List of extension codes from ServerHello - /// * `alpn` - ALPN protocol selected (e.g., "h2", "http/1.1") - pub fn from_server_hello( - is_quic: bool, - version: u16, - cipher: u16, - extensions: &[u16], - alpn: Option<&str>, - ) -> Self { - // Proto: q=QUIC, t=TCP - let proto = if is_quic { "q" } else { "t" }; - - // Version mapping - let version_str = match version { - 0x0304 => "13", // TLS 1.3 - 0x0303 => "12", // TLS 1.2 - 0x0302 => "11", // TLS 1.1 - 0x0301 => "10", // TLS 1.0 - 0x0300 => "s3", // SSL 3.0 - 0x0002 => "s2", // SSL 2.0 - 0xfeff => "d1", // DTLS 1.0 - 0xfefd => "d2", // DTLS 1.2 - 0xfefc => "d3", // DTLS 1.3 - _ => "00", - }; - - // Extension count (max 99) - let ext_count = format!("{:02}", extensions.len().min(99)); - - // ALPN: first and last character - let alpn_code = if let Some(alpn_str) = alpn { - if alpn_str.is_empty() { - "00".to_string() - } else if alpn_str.len() == 1 { - let ch = alpn_str.chars().next().unwrap(); - format!("{}{}", ch, ch) - } else { - let first = alpn_str.chars().next().unwrap(); - let last = alpn_str.chars().last().unwrap(); - format!("{}{}", first, last) - } - } else { - "00".to_string() - }; - - // Build part A - let part_a = format!("{}{}{}{}", proto, version_str, ext_count, alpn_code); - - // Build part B (cipher in hex) - let part_b = format!("{:04x}", cipher); - - // Build part C (extensions hash) - let ext_strings: Vec = extensions.iter().map(|e| format!("{:04x}", e)).collect(); - let ext_string = ext_strings.join(","); - let part_c = if ext_string.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(ext_string.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let fingerprint = format!("{}_{}_{}",part_a, part_b, part_c); - - Self { - fingerprint, - proto: proto.to_string(), - version: version_str.to_string(), - cipher, - extensions: extensions.to_vec(), - alpn: alpn.map(|s| s.to_string()), - } - } - - /// Get raw (non-hashed) fingerprint - pub fn raw(&self) -> String { - let proto = &self.proto; - let version = &self.version; - let ext_count = format!("{:02}", self.extensions.len().min(99)); - let alpn_code = self.alpn.as_ref().map_or("00".to_string(), |a| { - if a.is_empty() { - "00".to_string() - } else if a.len() == 1 { - format!("{}{}", a, a) - } else { - format!("{}{}", a.chars().next().unwrap(), a.chars().last().unwrap()) - } - }); - - let part_a = format!("{}{}{}{}", proto, version, ext_count, alpn_code); - let part_b = format!("{:04x}", self.cipher); - let ext_strings: Vec = self.extensions.iter().map(|e| format!("{:04x}", e)).collect(); - let part_c = ext_strings.join(","); - - format!("{}_{}_{}", part_a, part_b, part_c) - } -} - -/// JA4X: X.509 Certificate Fingerprint -/// Official Format: {issuer_rdns_hash}_{subject_rdns_hash}_{extensions_hash} -/// Example: "aae71e8db6d7_b186095e22b6_c1a4f9e7d8b3" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/rust/ja4x/src/lib.rs -#[derive(Debug, Clone)] -pub struct Ja4xFingerprint { - pub fingerprint: String, - pub issuer_rdns: String, - pub subject_rdns: String, - pub extensions: String, -} - -impl Ja4xFingerprint { - /// Generate JA4X fingerprint from X.509 certificate attributes - /// - /// # Arguments - /// * `issuer_oids` - List of issuer RDN OIDs in hex (e.g., ["550406", "55040a"]) - /// * `subject_oids` - List of subject RDN OIDs in hex - /// * `extension_oids` - List of extension OIDs in hex - pub fn from_x509( - issuer_oids: 
&[String], - subject_oids: &[String], - extension_oids: &[String], - ) -> Self { - let issuer_rdns = issuer_oids.join(","); - let subject_rdns = subject_oids.join(","); - let extensions = extension_oids.join(","); - - let issuer_hash = if issuer_rdns.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(issuer_rdns.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let subject_hash = if subject_rdns.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(subject_rdns.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let extensions_hash = if extensions.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(extensions.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let fingerprint = format!("{}_{}_{}", issuer_hash, subject_hash, extensions_hash); - - Self { - fingerprint, - issuer_rdns, - subject_rdns, - extensions, - } - } - - /// Get raw (non-hashed) fingerprint - pub fn raw(&self) -> String { - format!("{}_{}_{}", self.issuer_rdns, self.subject_rdns, self.extensions) - } - - /// Helper to convert OID string to hex representation - /// Example: "2.5.4.3" -> "550403" - pub fn oid_to_hex(oid: &str) -> String { - let parts: Vec = oid.split('.').filter_map(|s| s.parse().ok()).collect(); - if parts.len() < 2 { - return String::new(); - } - - let mut result: Vec = vec![(parts[0] * 40 + parts[1]) as u8]; - - for &part in &parts[2..] { - let encoded = Self::encode_variable_length(part); - result.extend(encoded); - } - - result.iter().map(|b| format!("{:02x}", b)).collect::() - } - - /// Encode value as variable-length quantity (for OID encoding) - fn encode_variable_length(mut value: u32) -> Vec { - let mut output = Vec::new(); - let mut mask = 0x00; - - while value >= 0x80 { - output.insert(0, ((value & 0x7F) | mask) as u8); - value >>= 7; - mask = 0x80; - } - output.insert(0, (value | mask) as u8); - output - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hyper::HeaderMap; - - #[test] - fn test_ja4t_fingerprint() { - let ja4t = Ja4tFingerprint::from_tcp_data( - 65535, // window_size - 64, // ttl - 1460, // mss - 7, // window_scale - &[2, 4, 5, 180, 4, 2, 8, 10], // TCP options (MSS, SACK, Timestamps) - ); - - assert_eq!(ja4t.window_size, 65535); - assert_eq!(ja4t.ttl, 64); - assert_eq!(ja4t.mss, 1460); - assert_eq!(ja4t.window_scale, 7); - // Official format: {window_size}_{options}_{mss}_{window_scale} - assert!(ja4t.fingerprint.starts_with("65535_")); - assert!(ja4t.fingerprint.contains("_1460_7")); - assert!(!ja4t.hash().is_empty()); - assert_eq!(ja4t.hash().len(), 12); - } - - #[test] - fn test_ja4h_fingerprint() { - let mut headers = HeaderMap::new(); - headers.insert("user-agent", "Mozilla/5.0".parse().unwrap()); - headers.insert("accept", "*/*".parse().unwrap()); - headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap()); - headers.insert("cookie", "session=abc123; id=xyz789".parse().unwrap()); - headers.insert("referer", "https://example.com".parse().unwrap()); - - let ja4h = Ja4hFingerprint::from_http_request( - "GET", - "HTTP/1.1", - &headers, - ); - - assert_eq!(ja4h.method, "GET"); - assert_eq!(ja4h.version, "HTTP/1.1"); - assert!(ja4h.has_cookie); - assert!(ja4h.has_referer); - assert_eq!(ja4h.language, "enus"); - // Official format: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values} - assert!(ja4h.fingerprint.starts_with("ge11cr")); - 
assert!(ja4h.fingerprint.contains("enus_")); - // Should have 4 parts separated by underscores - assert_eq!(ja4h.fingerprint.matches('_').count(), 3); - } - - #[test] - fn test_ja4l_measurement() { - let mut ja4l = Ja4lMeasurement::new(); - - // Simulate TCP handshake timing (in microseconds) - ja4l.set_syn(1000000, 64); // SYN at 1s, TTL 64 - ja4l.set_synack(1025000, 128); // SYNACK at 1.025s, TTL 128 (25ms later) - ja4l.set_ack(1050000); // ACK at 1.050s (25ms after SYNACK) - - // Client fingerprint: (ACK - SYNACK) / 2 = (1050000 - 1025000) / 2 = 12500μs - let client_fp = ja4l.fingerprint_client().unwrap(); - assert_eq!(client_fp, "12500_64"); - - // Server fingerprint: (SYNACK - SYN) / 2 = (1025000 - 1000000) / 2 = 12500μs - let server_fp = ja4l.fingerprint_server().unwrap(); - assert_eq!(server_fp, "12500_128"); - } - - #[test] - fn test_ja4s_fingerprint() { - // TLS 1.3 ServerHello with extensions - let ja4s = Ja4sFingerprint::from_server_hello( - false, // TCP (not QUIC) - 0x0304, // TLS 1.3 - 0x1301, // TLS_AES_128_GCM_SHA256 - &[0x002b, 0x0033], // supported_versions, key_share - Some("h2"), // ALPN - ); - - assert_eq!(ja4s.proto, "t"); - assert_eq!(ja4s.version, "13"); - assert_eq!(ja4s.cipher, 0x1301); - assert!(ja4s.fingerprint.starts_with("t1302h2_1301_")); - assert_eq!(ja4s.fingerprint.matches('_').count(), 2); - - // Verify raw format - let raw = ja4s.raw(); - assert!(raw.starts_with("t1302h2_1301_")); - assert!(raw.contains("002b,0033") || raw.contains("002b") && raw.contains("0033")); - } - - #[test] - fn test_ja4s_quic() { - // QUIC ServerHello - let ja4s = Ja4sFingerprint::from_server_hello( - true, // QUIC - 0x0304, // TLS 1.3 - 0x1302, // TLS_AES_256_GCM_SHA384 - &[0x002b], // supported_versions - Some("h3"), // HTTP/3 - ); - - assert_eq!(ja4s.proto, "q"); - assert!(ja4s.fingerprint.starts_with("q1301h3_1302_")); - } - - #[test] - fn test_ja4x_fingerprint() { - // X.509 certificate with common OIDs - let issuer_oids = vec![ - "550406".to_string(), // countryName - "55040a".to_string(), // organizationName - "550403".to_string(), // commonName - ]; - - let subject_oids = vec![ - "550406".to_string(), // countryName - "550403".to_string(), // commonName - ]; - - let extensions = vec![ - "551d0f".to_string(), // keyUsage - "551d25".to_string(), // extKeyUsage - "551d11".to_string(), // subjectAltName - ]; - - let ja4x = Ja4xFingerprint::from_x509(&issuer_oids, &subject_oids, &extensions); - - // Should have 3 parts separated by underscores - assert_eq!(ja4x.fingerprint.matches('_').count(), 2); - - // Each hash should be 12 characters - let parts: Vec<&str> = ja4x.fingerprint.split('_').collect(); - assert_eq!(parts.len(), 3); - assert_eq!(parts[0].len(), 12); - assert_eq!(parts[1].len(), 12); - assert_eq!(parts[2].len(), 12); - - // Verify raw format - let raw = ja4x.raw(); - assert!(raw.contains("550406,55040a,550403")); - assert!(raw.contains("551d0f,551d25,551d11")); - } - - #[test] - fn test_ja4x_oid_conversion() { - // Test OID to hex conversion - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.3"), "550403"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.6"), "550406"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.10"), "55040a"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.29.15"), "551d0f"); - - // Invalid OID - assert_eq!(Ja4xFingerprint::oid_to_hex("2"), ""); - assert_eq!(Ja4xFingerprint::oid_to_hex(""), ""); - } -} - +use sha2::{Digest, Sha256}; +use hyper::HeaderMap; + +/// JA4T: TCP Fingerprint from TCP options +/// Official Format: 
{window_size}_{tcp_options}_{mss}_{window_scale}
+/// Example: "65535_2-4-8-1-3_1460_7"
+/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4t/main.zeek
+#[derive(Debug, Clone)]
+pub struct Ja4tFingerprint {
+    pub fingerprint: String,
+    pub window_size: u16,
+    pub ttl: u16,
+    pub mss: u16,
+    pub window_scale: u8,
+    pub options: Vec<u8>,
+}
+
+impl Ja4tFingerprint {
+    /// Generate JA4T fingerprint from TCP parameters
+    /// TCP options are represented by their kind numbers:
+    /// 0 = EOL, 1 = NOP, 2 = MSS, 3 = Window Scale, 4 = SACK Permitted,
+    /// 5 = SACK, 8 = Timestamps, etc.
+    pub fn from_tcp_data(
+        window_size: u16,
+        ttl: u16,
+        mss: u16,
+        window_scale: u8,
+        options: &[u8],
+    ) -> Self {
+        // Extract TCP option kinds
+        let mut option_kinds = Vec::new();
+        let mut i = 0;
+
+        while i < options.len() {
+            let kind = options[i];
+
+            match kind {
+                0 => break, // EOL
+                1 => {
+                    // NOP - single byte
+                    option_kinds.push(kind);
+                    i += 1;
+                }
+                _ => {
+                    // Options with length
+                    if i + 1 < options.len() {
+                        let len = options[i + 1] as usize;
+                        option_kinds.push(kind);
+                        i += len.max(2);
+                    } else {
+                        break;
+                    }
+                }
+            }
+        }
+
+        // Build fingerprint: window_size_options_mss_window_scale
+        // Official format from Zeek implementation
+        let options_str = option_kinds
+            .iter()
+            .map(|k| k.to_string())
+            .collect::<Vec<String>>()
+            .join("-");
+
+        let fingerprint = format!("{}_{}_{}_{}",
+            window_size,
+            if options_str.is_empty() { "0" } else { &options_str },
+            mss,
+            window_scale
+        );
+
+        Self {
+            fingerprint,
+            window_size,
+            ttl,
+            mss,
+            window_scale,
+            options: option_kinds,
+        }
+    }
+
+    /// Get the JA4T hash (first 12 characters of SHA-256)
+    pub fn hash(&self) -> String {
+        let digest = Sha256::digest(self.fingerprint.as_bytes());
+        let hex = format!("{:x}", digest);
+        hex[..12].to_string()
+    }
+}
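A minimal usage sketch of the JA4T API above, for reviewers. The `crate::ja4_plus` path is assumed from this patch's `pub mod ja4_plus;` declaration in src/main.rs, and the option bytes are illustrative, not captured traffic:

// Sketch only: module path and option bytes are assumptions, not part of the patch.
use crate::ja4_plus::Ja4tFingerprint;

fn ja4t_demo() {
    // Kind 2 (MSS, len 4, value 1460), kind 4 (SACK permitted, len 2),
    // kind 3 (window scale, len 3, shift 7).
    let opts = [2u8, 4, 5, 180, 4, 2, 3, 3, 7];
    let ja4t = Ja4tFingerprint::from_tcp_data(65535, 64, 1460, 7, &opts);
    // Raw form is {window_size}_{kinds}_{mss}_{window_scale}; hash() is the
    // first 12 hex characters of its SHA-256 digest.
    assert_eq!(ja4t.fingerprint, "65535_2-4-3_1460_7");
    assert_eq!(ja4t.hash().len(), 12);
}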
+
+/// JA4H: HTTP Header Fingerprint
+/// Official Format: {method}{version}{cookie}{referer}{header_count}{language}_{headers_hash}_{cookie_names_hash}_{cookie_values_hash}
+/// Example: "ge11cr15enus_a1b2c3d4e5f6_123456789abc_def012345678"
+/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4h/main.zeek
+#[derive(Debug, Clone)]
+pub struct Ja4hFingerprint {
+    pub fingerprint: String,
+    pub method: String,
+    pub version: String,
+    pub has_cookie: bool,
+    pub has_referer: bool,
+    pub header_count: usize,
+    pub language: String,
+}
+
+impl Ja4hFingerprint {
+    /// Generate JA4H fingerprint from HTTP request
+    pub fn from_http_request(
+        method: &str,
+        version: &str,
+        headers: &HeaderMap,
+    ) -> Self {
+        // Method: first 2 letters, lowercase (ge=GET, po=POST, etc.)
+        let method_str = method.to_lowercase();
+        let method_code = if method_str.len() >= 2 {
+            &method_str[..2]
+        } else {
+            "un" // unknown
+        };
+
+        // Version: 10, 11, 20, 30
+        let version_code = match version {
+            "HTTP/0.9" => "09",
+            "HTTP/1.0" => "10",
+            "HTTP/1.1" => "11",
+            "HTTP/2.0" | "HTTP/2" => "20",
+            "HTTP/3.0" | "HTTP/3" => "30",
+            _ => "00",
+        };
+
+        // Check if Cookie and Referer headers exist
+        let has_cookie = headers.contains_key("cookie");
+        let cookie_flag = if has_cookie { "c" } else { "n" };
+
+        let has_referer = headers.contains_key("referer");
+        let referer_flag = if has_referer { "r" } else { "n" };
+
+        // Extract language from Accept-Language header
+        let language = if let Some(lang_header) = headers.get("accept-language") {
+            if let Ok(lang_str) = lang_header.to_str() {
+                // Take first language, remove hyphens, lowercase, pad to 4 chars
+                let primary_lang = lang_str.split(',').next().unwrap_or("");
+                let clean_lang = primary_lang
+                    .split(';')
+                    .next()
+                    .unwrap_or("")
+                    .replace('-', "")
+                    .to_lowercase();
+
+                let mut lang_code = clean_lang.chars().take(4).collect::<String>();
+                while lang_code.len() < 4 {
+                    lang_code.push('0');
+                }
+                lang_code
+            } else {
+                "0000".to_string()
+            }
+        } else {
+            "0000".to_string()
+        };
+
+        // Collect header names (excluding Cookie and Referer)
+        let mut header_names: Vec<String> = headers
+            .iter()
+            .filter_map(|(name, _)| {
+                let name_str = name.as_str().to_lowercase();
+                if name_str == "cookie" || name_str == "referer" {
+                    None
+                } else {
+                    Some(name_str)
+                }
+            })
+            .collect();
+
+        header_names.sort();
+        let header_count = header_names.len().min(99);
+        let header_count_str = format!("{:02}", header_count);
+
+        // Create header hash
+        let header_string = header_names.join(",");
+        let header_hash = if header_string.is_empty() {
+            "000000000000".to_string()
+        } else {
+            let digest = Sha256::digest(header_string.as_bytes());
+            let hex = format!("{:x}", digest);
+            hex[..12].to_string()
+        };
+
+        // Parse cookies if they exist
+        let (cookie_names_hash, cookie_values_hash) = if let Some(cookie_value) = headers.get("cookie") {
+            if let Ok(cookie_str) = cookie_value.to_str() {
+                // Parse cookie pairs: name=value; name2=value2
+                let mut cookie_names: Vec<String> = Vec::new();
+                let mut cookie_values: Vec<String> = Vec::new();
+
+                for part in cookie_str.split(';') {
+                    let trimmed = part.trim();
+                    if let Some((name, _value)) = trimmed.split_once('=') {
+                        cookie_names.push(name.trim().to_string());
+                        cookie_values.push(trimmed.to_string()); // Full "name=value"
+                    }
+                }
+
+                // Sort separately
+                cookie_names.sort();
+                cookie_values.sort();
+
+                // Hash separately
+                let names_hash = if cookie_names.is_empty() {
+                    "000000000000".to_string()
+                } else {
+                    let digest = Sha256::digest(cookie_names.join(",").as_bytes());
+                    let hex = format!("{:x}", digest);
+                    hex[..12].to_string()
+                };
+
+                let values_hash = if cookie_values.is_empty() {
+                    "000000000000".to_string()
+                } else {
+                    let digest = Sha256::digest(cookie_values.join(",").as_bytes());
+                    let hex = format!("{:x}", digest);
+                    hex[..12].to_string()
+                };
+
+                (names_hash, values_hash)
+            } else {
+                ("000000000000".to_string(), "000000000000".to_string())
+            }
+        } else {
+            ("000000000000".to_string(), "000000000000".to_string())
+        };
+
+        // Build fingerprint: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values}
+        let fingerprint = format!(
+            "{}{}{}{}{}{}_{}_{}_{}",
+            method_code,
+            version_code,
+            cookie_flag,
+            referer_flag,
+            header_count_str,
+            language,
+            header_hash,
+            cookie_names_hash,
+            cookie_values_hash
+        );
+
+        Self {
+            fingerprint,
+            method: method.to_string(),
+            version: version.to_string(),
+            has_cookie,
+            has_referer,
+            header_count,
+            language,
+        }
+    }
+}
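A sketch of how the JA4H prefix is derived from a request, under the same assumed `ja4_plus` module path. With no Cookie or Referer and two remaining headers, part A is method + version + "nn" + count + language:

// Sketch only: assumes this crate's ja4_plus module and the hyper crate.
use crate::ja4_plus::Ja4hFingerprint;
use hyper::HeaderMap;

fn ja4h_demo() {
    let mut headers = HeaderMap::new();
    headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap());
    headers.insert("user-agent", "curl/8.5.0".parse().unwrap());

    let ja4h = Ja4hFingerprint::from_http_request("GET", "HTTP/1.1", &headers);
    // "ge" + "11" + "n" (no cookie) + "n" (no referer) + "02" headers + "enus"
    assert!(ja4h.fingerprint.starts_with("ge11nn02enus_"));
    // Absent cookies hash to the all-zero placeholder.
    assert!(ja4h.fingerprint.ends_with("_000000000000_000000000000"));
}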
+
+/// JA4L: Latency Fingerprint
+/// Official Format: {rtt_microseconds}_{ttl}
+/// Measures round-trip time between SYN and SYNACK packets
+/// Example: "12500_64" (12.5ms RTT, TTL 64)
+/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4l/main.zeek
+#[derive(Debug, Clone)]
+pub struct Ja4lMeasurement {
+    pub syn_time: Option<u64>,    // Microseconds
+    pub synack_time: Option<u64>, // Microseconds
+    pub ack_time: Option<u64>,    // Microseconds
+    pub ttl_client: Option<u8>,
+    pub ttl_server: Option<u8>,
+}
+
+impl Ja4lMeasurement {
+    pub fn new() -> Self {
+        Self {
+            syn_time: None,
+            synack_time: None,
+            ack_time: None,
+            ttl_client: None,
+            ttl_server: None,
+        }
+    }
+
+    /// Record SYN packet timestamp
+    pub fn set_syn(&mut self, timestamp_us: u64, ttl: u8) {
+        self.syn_time = Some(timestamp_us);
+        self.ttl_client = Some(ttl);
+    }
+
+    /// Record SYNACK packet timestamp
+    pub fn set_synack(&mut self, timestamp_us: u64, ttl: u8) {
+        self.synack_time = Some(timestamp_us);
+        self.ttl_server = Some(ttl);
+    }
+
+    /// Record ACK packet timestamp
+    pub fn set_ack(&mut self, timestamp_us: u64) {
+        self.ack_time = Some(timestamp_us);
+    }
+
+    /// Generate JA4L client fingerprint
+    /// Format: {client_rtt_us}_{client_ttl}
+    /// RTT = (ACK - SYNACK) / 2
+    pub fn fingerprint_client(&self) -> Option<String> {
+        let synack = self.synack_time?;
+        let ack = self.ack_time?;
+        let ttl = self.ttl_client?;
+
+        // Calculate client-side RTT (half of ACK-SYNACK time)
+        let rtt_us = (ack.saturating_sub(synack)) / 2;
+
+        Some(format!("{}_{}", rtt_us, ttl))
+    }
+
+    /// Generate JA4L server fingerprint
+    /// Format: {server_rtt_us}_{server_ttl}
+    /// RTT = (SYNACK - SYN) / 2
+    pub fn fingerprint_server(&self) -> Option<String> {
+        let syn = self.syn_time?;
+        let synack = self.synack_time?;
+        let ttl = self.ttl_server?;
+
+        // Calculate server-side RTT (half of SYNACK-SYN time)
+        let rtt_us = (synack.saturating_sub(syn)) / 2;
+
+        Some(format!("{}_{}", rtt_us, ttl))
+    }
+
+    /// Legacy format for compatibility (if needed)
+    /// Returns both client and server measurements
+    pub fn fingerprint_combined(&self) -> Option<String> {
+        let client = self.fingerprint_client()?;
+        let server = self.fingerprint_server()?;
+
+        Some(format!("c:{},s:{}", client, server))
+    }
+}
+
+impl Default for Ja4lMeasurement {
+    fn default() -> Self {
+        Self::new()
+    }
+}
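A worked JA4L example under the same module-path assumption; the timestamps are synthetic microsecond values chosen to make the halved RTTs obvious:

// Sketch only: synthetic timestamps, assumed crate::ja4_plus path.
use crate::ja4_plus::Ja4lMeasurement;

fn ja4l_demo() {
    let mut m = Ja4lMeasurement::new();
    m.set_syn(2_000_000, 64);     // client SYN at t=2.000s, TTL 64
    m.set_synack(2_030_000, 255); // server SYN-ACK 30ms later, TTL 255
    m.set_ack(2_040_000);         // client ACK 10ms after the SYN-ACK

    // Server side: (2_030_000 - 2_000_000) / 2 = 15_000 microseconds
    assert_eq!(m.fingerprint_server().as_deref(), Some("15000_255"));
    // Client side: (2_040_000 - 2_030_000) / 2 = 5_000 microseconds
    assert_eq!(m.fingerprint_client().as_deref(), Some("5000_64"));
}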
+
+/// JA4S: TLS Server Response Fingerprint
+/// Official Format: {proto}{version}{ext_count}{alpn}_{cipher}_{extensions_hash}
+/// Example: "t130200_1301_a56c5b993250"
+/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4s/main.zeek
+#[derive(Debug, Clone)]
+pub struct Ja4sFingerprint {
+    pub fingerprint: String,
+    pub proto: String,
+    pub version: String,
+    pub cipher: u16,
+    pub extensions: Vec<u16>,
+    pub alpn: Option<String>,
+}
+
+impl Ja4sFingerprint {
+    /// Generate JA4S fingerprint from TLS ServerHello
+    ///
+    /// # Arguments
+    /// * `is_quic` - true if QUIC, false if TCP TLS
+    /// * `version` - TLS version (0x0304 for TLS 1.3, etc.)
+    /// * `cipher` - Cipher suite selected by server
+    /// * `extensions` - List of extension codes from ServerHello
+    /// * `alpn` - ALPN protocol selected (e.g., "h2", "http/1.1")
+    pub fn from_server_hello(
+        is_quic: bool,
+        version: u16,
+        cipher: u16,
+        extensions: &[u16],
+        alpn: Option<&str>,
+    ) -> Self {
+        // Proto: q=QUIC, t=TCP
+        let proto = if is_quic { "q" } else { "t" };
+
+        // Version mapping
+        let version_str = match version {
+            0x0304 => "13", // TLS 1.3
+            0x0303 => "12", // TLS 1.2
+            0x0302 => "11", // TLS 1.1
+            0x0301 => "10", // TLS 1.0
+            0x0300 => "s3", // SSL 3.0
+            0x0002 => "s2", // SSL 2.0
+            0xfeff => "d1", // DTLS 1.0
+            0xfefd => "d2", // DTLS 1.2
+            0xfefc => "d3", // DTLS 1.3
+            _ => "00",
+        };
+
+        // Extension count (max 99)
+        let ext_count = format!("{:02}", extensions.len().min(99));
+
+        // ALPN: first and last character
+        let alpn_code = if let Some(alpn_str) = alpn {
+            if alpn_str.is_empty() {
+                "00".to_string()
+            } else if alpn_str.len() == 1 {
+                let ch = alpn_str.chars().next().unwrap();
+                format!("{}{}", ch, ch)
+            } else {
+                let first = alpn_str.chars().next().unwrap();
+                let last = alpn_str.chars().last().unwrap();
+                format!("{}{}", first, last)
+            }
+        } else {
+            "00".to_string()
+        };
+
+        // Build part A
+        let part_a = format!("{}{}{}{}", proto, version_str, ext_count, alpn_code);
+
+        // Build part B (cipher in hex)
+        let part_b = format!("{:04x}", cipher);
+
+        // Build part C (extensions hash)
+        let ext_strings: Vec<String> = extensions.iter().map(|e| format!("{:04x}", e)).collect();
+        let ext_string = ext_strings.join(",");
+        let part_c = if ext_string.is_empty() {
+            "000000000000".to_string()
+        } else {
+            let digest = Sha256::digest(ext_string.as_bytes());
+            let hex = format!("{:x}", digest);
+            hex[..12].to_string()
+        };
+
+        let fingerprint = format!("{}_{}_{}", part_a, part_b, part_c);
+
+        Self {
+            fingerprint,
+            proto: proto.to_string(),
+            version: version_str.to_string(),
+            cipher,
+            extensions: extensions.to_vec(),
+            alpn: alpn.map(|s| s.to_string()),
+        }
+    }
+
+    /// Get raw (non-hashed) fingerprint
+    pub fn raw(&self) -> String {
+        let proto = &self.proto;
+        let version = &self.version;
+        let ext_count = format!("{:02}", self.extensions.len().min(99));
+        let alpn_code = self.alpn.as_ref().map_or("00".to_string(), |a| {
+            if a.is_empty() {
+                "00".to_string()
+            } else if a.len() == 1 {
+                format!("{}{}", a, a)
+            } else {
+                format!("{}{}", a.chars().next().unwrap(), a.chars().last().unwrap())
+            }
+        });
+
+        let part_a = format!("{}{}{}{}", proto, version, ext_count, alpn_code);
+        let part_b = format!("{:04x}", self.cipher);
+        let ext_strings: Vec<String> = self.extensions.iter().map(|e| format!("{:04x}", e)).collect();
+        let part_c = ext_strings.join(",");
+
+        format!("{}_{}_{}", part_a, part_b, part_c)
+    }
+}
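One more sketch, showing JA4S part A for a plain TLS 1.2 ServerHello with no ALPN. The cipher and extension codes are illustrative IANA values, and the module path is assumed as above:

// Sketch only: illustrative ServerHello parameters, not captured traffic.
use crate::ja4_plus::Ja4sFingerprint;

fn ja4s_demo() {
    let ja4s = Ja4sFingerprint::from_server_hello(
        false,              // TCP, not QUIC
        0x0303,             // TLS 1.2
        0xc02f,             // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
        &[0xff01, 0x000b],  // renegotiation_info, ec_point_formats
        None,               // no ALPN negotiated
    );
    // "t" + "12" + "02" extensions + "00" for missing ALPN, then the cipher in hex.
    assert!(ja4s.fingerprint.starts_with("t120200_c02f_"));
}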
+
+/// JA4X: X.509 Certificate Fingerprint
+/// Official Format: {issuer_rdns_hash}_{subject_rdns_hash}_{extensions_hash}
+/// Example: "aae71e8db6d7_b186095e22b6_c1a4f9e7d8b3"
+/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/rust/ja4x/src/lib.rs
+#[derive(Debug, Clone)]
+pub struct Ja4xFingerprint {
+    pub fingerprint: String,
+    pub issuer_rdns: String,
+    pub subject_rdns: String,
+    pub extensions: String,
+}
+
+impl Ja4xFingerprint {
+    /// Generate JA4X fingerprint from X.509 certificate attributes
+    ///
+    /// # Arguments
+    /// * `issuer_oids` - List of issuer RDN OIDs in hex (e.g., ["550406", "55040a"])
+    /// * `subject_oids` - List of subject RDN OIDs in hex
+    /// * `extension_oids` - List of extension OIDs in hex
+    pub fn from_x509(
+        issuer_oids: &[String],
+        subject_oids: &[String],
+        extension_oids: &[String],
+    ) -> Self {
+        let issuer_rdns = issuer_oids.join(",");
+        let subject_rdns = subject_oids.join(",");
+        let extensions = extension_oids.join(",");
+
+        let issuer_hash = if issuer_rdns.is_empty() {
+            "000000000000".to_string()
+        } else {
+            let digest = Sha256::digest(issuer_rdns.as_bytes());
+            let hex = format!("{:x}", digest);
+            hex[..12].to_string()
+        };
+
+        let subject_hash = if subject_rdns.is_empty() {
+            "000000000000".to_string()
+        } else {
+            let digest = Sha256::digest(subject_rdns.as_bytes());
+            let hex = format!("{:x}", digest);
+            hex[..12].to_string()
+        };
+
+        let extensions_hash = if extensions.is_empty() {
+            "000000000000".to_string()
+        } else {
+            let digest = Sha256::digest(extensions.as_bytes());
+            let hex = format!("{:x}", digest);
+            hex[..12].to_string()
+        };
+
+        let fingerprint = format!("{}_{}_{}", issuer_hash, subject_hash, extensions_hash);
+
+        Self {
+            fingerprint,
+            issuer_rdns,
+            subject_rdns,
+            extensions,
+        }
+    }
+
+    /// Get raw (non-hashed) fingerprint
+    pub fn raw(&self) -> String {
+        format!("{}_{}_{}", self.issuer_rdns, self.subject_rdns, self.extensions)
+    }
+
+    /// Helper to convert OID string to hex representation
+    /// Example: "2.5.4.3" -> "550403"
+    pub fn oid_to_hex(oid: &str) -> String {
+        let parts: Vec<u32> = oid.split('.').filter_map(|s| s.parse().ok()).collect();
+        if parts.len() < 2 {
+            return String::new();
+        }
+
+        let mut result: Vec<u8> = vec![(parts[0] * 40 + parts[1]) as u8];
+
+        for &part in &parts[2..] {
+            let encoded = Self::encode_variable_length(part);
+            result.extend(encoded);
+        }
+
+        result.iter().map(|b| format!("{:02x}", b)).collect::<String>()
+    }
+
+    /// Encode value as variable-length quantity (for OID encoding)
+    fn encode_variable_length(mut value: u32) -> Vec<u8> {
+        let mut output = Vec::new();
+        let mut mask = 0x00;
+
+        while value >= 0x80 {
+            output.insert(0, ((value & 0x7F) | mask) as u8);
+            value >>= 7;
+            mask = 0x80;
+        }
+        output.insert(0, (value | mask) as u8);
+        output
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hyper::HeaderMap;
+
+    #[test]
+    fn test_ja4t_fingerprint() {
+        let ja4t = Ja4tFingerprint::from_tcp_data(
+            65535, // window_size
+            64, // ttl
+            1460, // mss
+            7, // window_scale
+            &[2, 4, 5, 180, 4, 2, 8, 10], // TCP options (MSS, SACK, Timestamps)
+        );
+
+        assert_eq!(ja4t.window_size, 65535);
+        assert_eq!(ja4t.ttl, 64);
+        assert_eq!(ja4t.mss, 1460);
+        assert_eq!(ja4t.window_scale, 7);
+        // Official format: {window_size}_{options}_{mss}_{window_scale}
+        assert!(ja4t.fingerprint.starts_with("65535_"));
+        assert!(ja4t.fingerprint.contains("_1460_7"));
+        assert!(!ja4t.hash().is_empty());
+        assert_eq!(ja4t.hash().len(), 12);
+    }
+
+    #[test]
+    fn test_ja4h_fingerprint() {
+        let mut headers = HeaderMap::new();
+        headers.insert("user-agent", "Mozilla/5.0".parse().unwrap());
+        headers.insert("accept", "*/*".parse().unwrap());
+        headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap());
+        headers.insert("cookie", "session=abc123; id=xyz789".parse().unwrap());
+        headers.insert("referer", "https://example.com".parse().unwrap());
+
+        let ja4h = Ja4hFingerprint::from_http_request(
+            "GET",
+            "HTTP/1.1",
+            &headers,
+        );
+
+        assert_eq!(ja4h.method, "GET");
+        assert_eq!(ja4h.version, "HTTP/1.1");
+        assert!(ja4h.has_cookie);
+        assert!(ja4h.has_referer);
+        assert_eq!(ja4h.language, "enus");
+        // Official format: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values}
+        assert!(ja4h.fingerprint.starts_with("ge11cr"));
+
assert!(ja4h.fingerprint.contains("enus_")); + // Should have 4 parts separated by underscores + assert_eq!(ja4h.fingerprint.matches('_').count(), 3); + } + + #[test] + fn test_ja4l_measurement() { + let mut ja4l = Ja4lMeasurement::new(); + + // Simulate TCP handshake timing (in microseconds) + ja4l.set_syn(1000000, 64); // SYN at 1s, TTL 64 + ja4l.set_synack(1025000, 128); // SYNACK at 1.025s, TTL 128 (25ms later) + ja4l.set_ack(1050000); // ACK at 1.050s (25ms after SYNACK) + + // Client fingerprint: (ACK - SYNACK) / 2 = (1050000 - 1025000) / 2 = 12500μs + let client_fp = ja4l.fingerprint_client().unwrap(); + assert_eq!(client_fp, "12500_64"); + + // Server fingerprint: (SYNACK - SYN) / 2 = (1025000 - 1000000) / 2 = 12500μs + let server_fp = ja4l.fingerprint_server().unwrap(); + assert_eq!(server_fp, "12500_128"); + } + + #[test] + fn test_ja4s_fingerprint() { + // TLS 1.3 ServerHello with extensions + let ja4s = Ja4sFingerprint::from_server_hello( + false, // TCP (not QUIC) + 0x0304, // TLS 1.3 + 0x1301, // TLS_AES_128_GCM_SHA256 + &[0x002b, 0x0033], // supported_versions, key_share + Some("h2"), // ALPN + ); + + assert_eq!(ja4s.proto, "t"); + assert_eq!(ja4s.version, "13"); + assert_eq!(ja4s.cipher, 0x1301); + assert!(ja4s.fingerprint.starts_with("t1302h2_1301_")); + assert_eq!(ja4s.fingerprint.matches('_').count(), 2); + + // Verify raw format + let raw = ja4s.raw(); + assert!(raw.starts_with("t1302h2_1301_")); + assert!(raw.contains("002b,0033") || raw.contains("002b") && raw.contains("0033")); + } + + #[test] + fn test_ja4s_quic() { + // QUIC ServerHello + let ja4s = Ja4sFingerprint::from_server_hello( + true, // QUIC + 0x0304, // TLS 1.3 + 0x1302, // TLS_AES_256_GCM_SHA384 + &[0x002b], // supported_versions + Some("h3"), // HTTP/3 + ); + + assert_eq!(ja4s.proto, "q"); + assert!(ja4s.fingerprint.starts_with("q1301h3_1302_")); + } + + #[test] + fn test_ja4x_fingerprint() { + // X.509 certificate with common OIDs + let issuer_oids = vec![ + "550406".to_string(), // countryName + "55040a".to_string(), // organizationName + "550403".to_string(), // commonName + ]; + + let subject_oids = vec![ + "550406".to_string(), // countryName + "550403".to_string(), // commonName + ]; + + let extensions = vec![ + "551d0f".to_string(), // keyUsage + "551d25".to_string(), // extKeyUsage + "551d11".to_string(), // subjectAltName + ]; + + let ja4x = Ja4xFingerprint::from_x509(&issuer_oids, &subject_oids, &extensions); + + // Should have 3 parts separated by underscores + assert_eq!(ja4x.fingerprint.matches('_').count(), 2); + + // Each hash should be 12 characters + let parts: Vec<&str> = ja4x.fingerprint.split('_').collect(); + assert_eq!(parts.len(), 3); + assert_eq!(parts[0].len(), 12); + assert_eq!(parts[1].len(), 12); + assert_eq!(parts[2].len(), 12); + + // Verify raw format + let raw = ja4x.raw(); + assert!(raw.contains("550406,55040a,550403")); + assert!(raw.contains("551d0f,551d25,551d11")); + } + + #[test] + fn test_ja4x_oid_conversion() { + // Test OID to hex conversion + assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.3"), "550403"); + assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.6"), "550406"); + assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.10"), "55040a"); + assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.29.15"), "551d0f"); + + // Invalid OID + assert_eq!(Ja4xFingerprint::oid_to_hex("2"), ""); + assert_eq!(Ja4xFingerprint::oid_to_hex(""), ""); + } +} + diff --git a/src/main.rs b/src/main.rs index 983c2bf..208b78e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,1398 +1,1398 @@ -use 
std::collections::HashMap; -use std::sync::Arc; -use std::str::FromStr; -use std::fs::File; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use std::mem::MaybeUninit; - -use anyhow::{Context, Result}; -use clap::Parser; -use daemonize::Daemonize; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use libbpf_rs::skel::{OpenSkel, SkelBuilder}; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use nix::net::if_::if_nametoindex; -use chrono::Utc; -use sha2::{Digest, Sha256}; - -#[global_allocator] -static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; - -pub mod access_log; -pub mod access_rules; -pub mod agent_status; -pub mod app_state; -#[cfg(feature = "proxy")] -pub mod captcha_server; -pub mod cli; -pub mod content_scanning; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod firewall; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "firewall_noop.rs"] -pub mod firewall; -pub mod http_client; -pub mod waf; -pub mod threat; -pub mod redis; -#[cfg(feature = "proxy")] -pub mod proxy_protocol; -pub mod authcheck; -#[cfg(feature = "proxy")] -pub mod http_proxy; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod bpf { - // Include the skeleton generated by build.rs into OUT_DIR at compile time - include!(concat!(env!("OUT_DIR"), "/filter.skel.rs")); -} -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "bpf_stub.rs"] -pub mod bpf; - -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod bpf_stats; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "bpf_stats_noop.rs"] -pub mod bpf_stats; -pub mod ja4_plus; -pub mod utils; -pub mod worker; -#[cfg(feature = "proxy")] -pub mod acme; - -use tokio::signal; -use tokio::sync::watch; -#[cfg(feature = "proxy")] -use log::{error, info}; -use log::warn; - -use crate::app_state::AppState; -use crate::bpf_stats::BpfStatsCollector; -use crate::utils::tcp_fingerprint::TcpFingerprintCollector; -use crate::utils::tcp_fingerprint::TcpFingerprintConfig; -use crate::cli::{Args, Config}; -use crate::waf::wirefilter::init_config; -use crate::content_scanning::{init_content_scanner, ContentScanningConfig}; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use crate::utils::bpf_utils::bpf_attach_to_xdp; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use crate::utils::bpf_utils::bpf_detach_from_xdp; - -use crate::access_log::LogSenderConfig; -use crate::worker::log::set_log_sender_config; -use crate::worker::agent_status::AgentStatusWorker; -use crate::agent_status::AgentStatusIdentity; -use crate::authcheck::validate_api_key; -use crate::http_client::init_global_client; -use crate::waf::actions::captcha::{CaptchaConfig, CaptchaProvider, init_captcha_client, start_cache_cleanup_task}; - -fn main() -> Result<()> { - // Initialize rustls crypto provider early (must be done before any rustls operations) - rustls::crypto::ring::default_provider() - .install_default() - .map_err(|e| anyhow::anyhow!("Failed to install rustls crypto provider: {:?}", e))?; - - let args = Args::parse(); - // Handle clear certificate command (runs before loading full config) - if let Some(certificate_name) = &args.clear_certificate { - #[cfg(feature = "proxy")] - { - // Initialize minimal runtime for async operations - let rt = tokio::runtime::Runtime::new() - .context("Failed to create tokio runtime")?; - - // Load minimal config for Redis connection - let config = Config::load_from_args(&args) - .context("Failed to load 
configuration")?; - - // Initialize Redis if configured - if !config.redis.url.is_empty() { - rt.block_on(crate::redis::RedisManager::init( - &config.redis.url, - config.redis.prefix.clone(), - config.redis.ssl.as_ref(), - )) - .context("Failed to initialize Redis manager")?; - } - - // Get certificate path from config - let certificate_path = config - .pingora - .proxy_certificates - .clone() - .unwrap_or_else(|| "/etc/synapse/certs".to_string()); - - // Clear the certificate - rt.block_on(crate::worker::certificate::clear_certificate( - certificate_name, - &certificate_path, - ))?; - - return Ok(()); - } - - #[cfg(not(feature = "proxy"))] - { - let _ = certificate_name; - return Err(anyhow::anyhow!( - "clear-certificate is not available in agent-only builds" - )); - } - } - - // API key is optional - allow running in local mode without it - - // Load configuration - let config = Config::load_from_args(&args) - .context("Failed to load configuration")?; - - if config.mode == "proxy" && !cfg!(feature = "proxy") { - return Err(anyhow::anyhow!( - "proxy mode is not supported in agent-only builds (build with the `proxy` feature)" - )); - } - - // Handle daemonization before starting tokio runtime - if config.daemon.enabled { - let stdout = File::create(&config.daemon.stdout) - .with_context(|| format!("Failed to create stdout file: {}", config.daemon.stdout))?; - let stderr = File::create(&config.daemon.stderr) - .with_context(|| format!("Failed to create stderr file: {}", config.daemon.stderr))?; - - let mut daemonize = Daemonize::new() - .pid_file(&config.daemon.pid_file) - .chown_pid_file(config.daemon.chown_pid_file) - .working_directory(&config.daemon.working_directory) - .stdout(stdout) - .stderr(stderr); - - if let Some(user) = &config.daemon.user { - daemonize = daemonize.user(user.as_str()); - } - - if let Some(group) = &config.daemon.group { - daemonize = daemonize.group(group.as_str()); - } - - match daemonize.start() { - Ok(_) => { - // We're now in the daemon process, continue with application startup - } - Err(e) => { - eprintln!("Failed to daemonize: {}", e); - return Err(anyhow::anyhow!("Daemonization failed: {}", e)); - } - } - } - - // Set RUST_LOG environment variable from config so other modules can use it - let log_level = if !config.logging.level.is_empty() { - config.logging.level.to_lowercase() - } else { - match args.log_level { - crate::cli::LogLevel::Error => "error", - crate::cli::LogLevel::Warn => "warn", - crate::cli::LogLevel::Info => "info", - crate::cli::LogLevel::Debug => "debug", - crate::cli::LogLevel::Trace => "trace", - }.to_string() - }; - unsafe { - std::env::set_var("RUST_LOG", &log_level); - } - - // Initialize logger using config level (CLI overrides if provided explicitly) - // Note: env_logger writes to stderr by default, which is standard practice - { - use env_logger::Env; - let mut builder = env_logger::Builder::from_env(Env::default().default_filter_or("info")); - - // Use log level from config, or CLI if explicitly set - let level_filter = match log_level.as_str() { - "error" => log::LevelFilter::Error, - "warn" => log::LevelFilter::Warn, - "info" => log::LevelFilter::Info, - "debug" => log::LevelFilter::Debug, - "trace" => log::LevelFilter::Trace, - _ => args.log_level.to_level_filter(), - }; - builder.filter_level(level_filter); - builder.format_timestamp_secs(); - - // In daemon mode, write to stdout instead of stderr for better log separation - if config.daemon.enabled { - builder.target(env_logger::Target::Stdout); - } - - 
builder.try_init().ok();
-    }
-
-    // Start the tokio runtime and run the async application
-    let runtime = match config.worker_threads {
-        // Explicit config: use specified thread count (0 = single-threaded)
-        Some(0) => {
-            tokio::runtime::Builder::new_current_thread()
-                .enable_all()
-                .build()?
-        }
-        Some(threads) => {
-            tokio::runtime::Builder::new_multi_thread()
-                .worker_threads(threads)
-                .enable_all()
-                .build()?
-        }
-        None if config.mode == "agent" => {
-            tokio::runtime::Builder::new_current_thread()
-                .enable_all()
-                .build()?
-        }
-        None => {
-            tokio::runtime::Builder::new_multi_thread()
-                .enable_all()
-                .build()?
-        }
-    };
-    runtime.block_on(async_main(args, config))
-}
-
-#[allow(clippy::too_many_lines)]
-async fn async_main(args: Args, config: Config) -> Result<()> {
-    let started_at = Utc::now();
-
-    if config.daemon.enabled {
-        log::info!("Running in daemon mode (PID file: {})", config.daemon.pid_file);
-    }
-
-    // Initialize global HTTP client with keepalive configuration
-    if let Err(e) = init_global_client() {
-        log::warn!("Failed to initialize global HTTP client: {}", e);
-    }
-
-    // Track enabled features for startup summary
-    let is_agent_mode = config.mode == "agent";
-    let mut waf_enabled = false;
-    let mut threat_client_enabled = false;
-    let mut captcha_client_enabled = false;
-
-
-    let iface_names: Vec<String> = if !config.network.ifaces.is_empty() {
-        config.network.ifaces.clone()
-    } else {
-        vec![config.network.iface.clone()]
-    };
-
-    use crate::firewall::{FirewallBackend, FirewallMode, NftablesFirewall, IptablesFirewall};
-    use std::sync::Mutex;
-
-    #[allow(unused_mut)]
-    let mut skels: Vec<Arc<bpf::FilterSkel<'static>>> = Vec::new();
-    #[allow(unused_mut)]
-    let mut ifindices: Vec<i32> = Vec::new();
-    let mut firewall_backend = FirewallBackend::None;
-    let mut nftables_firewall: Option<Arc<Mutex<NftablesFirewall>>> = None;
-    let mut iptables_firewall: Option<Arc<Mutex<IptablesFirewall>>> = None;
-    let firewall_mode = config.network.firewall_mode;
-
-    // Track XDP modes per interface for startup summary
-    #[allow(unused_mut)]
-    let mut xdp_modes: Vec<(&str, &str)> = Vec::new();
-
-    if config.network.disable_xdp {
-        log::warn!("XDP disabled by config, will use nftables fallback");
-    } else {
-        #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))]
-        {
-            // Suppress libbpf output - errors will be handled gracefully with fallback to nftables/iptables
-            libbpf_rs::set_print(None);
-
-            for iface in &iface_names {
-                let boxed_open: Box<MaybeUninit<libbpf_rs::OpenObject>> = Box::new(MaybeUninit::uninit());
-                let open_object: &'static mut MaybeUninit<libbpf_rs::OpenObject> = Box::leak(boxed_open);
-                let skel_builder = bpf::FilterSkelBuilder::default();
-                match skel_builder.open(open_object).and_then(|o| o.load()) {
-                    Ok(mut skel) => {
-                        let ifindex = match if_nametoindex(iface.as_str()) {
-                            Ok(index) => index as i32,
-                            Err(e) => {
-                                log::error!("failed to get interface index for '{}': {e}", iface);
-                                continue;
-                            }
-                        };
-                        match bpf_attach_to_xdp(&mut skel, ifindex, Some(iface.as_str()), &config.network.ip_version) {
-                            Ok(mode) => {
-                                xdp_modes.push((iface.as_str(), mode.as_str()));
-                                skels.push(Arc::new(skel));
-                                ifindices.push(ifindex);
-                            }
-                            Err(e) => {
-                                // Check if error is EAFNOSUPPORT (error 97) - IPv6 might be disabled
-                                let error_str = e.to_string();
-                                if error_str.contains("97") || error_str.contains("Address family not supported") {
-                                    log::warn!("Failed to attach XDP to '{}': {} (IPv6 disabled)", iface, e);
-                                } else {
-                                    log::error!("Failed to attach XDP to '{}': {}", iface, e);
-                                }
-                            }
-                        }
-                    }
-                    Err(e) => {
-                        // Check for common BPF/kernel support errors
-                        let error_str = e.to_string();
-                        if
error_str.contains("Function not implemented") || - error_str.contains("ENOSYS") || - error_str.contains("error 38") || - error_str.contains("CONFIG_BPF_SYSCALL") { - log::warn!("BPF not supported by kernel for '{}': {}", iface, e); - } else { - log::warn!("failed to load BPF skeleton for '{}': {e}", iface); - } - } - } - } - } - - #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] - { - log::warn!("BPF support disabled at build time; skipping XDP attachment"); - log::warn!("Set 'disable_xdp: true' in config.yaml to suppress this warning"); - } - } - - // Determine firewall backend based on config and availability - log::info!("Firewall mode configured: {}", firewall_mode); - - match firewall_mode { - FirewallMode::Xdp => { - // Forced XDP mode - if !skels.is_empty() { - firewall_backend = FirewallBackend::Xdp; - let _ = access_rules::init_access_rules_from_global(&skels); - log::info!("Using XDP/BPF firewall backend (forced)"); - } else { - log::error!("XDP mode forced but BPF not available"); - } - } - FirewallMode::Nftables => { - // Forced nftables mode - if NftablesFirewall::is_available() { - match NftablesFirewall::new() { - Ok(nft_fw) => { - firewall_backend = FirewallBackend::Nftables; - nftables_firewall = Some(Arc::new(Mutex::new(nft_fw))); - log::info!("Using nftables firewall backend (forced)"); - } - Err(e) => { - log::error!("Failed to initialize nftables (forced): {}", e); - } - } - } else { - log::error!("nftables mode forced but nft command not available"); - } - } - FirewallMode::Iptables => { - // Forced iptables mode - if IptablesFirewall::is_available() { - match IptablesFirewall::new() { - Ok(ipt_fw) => { - firewall_backend = FirewallBackend::Iptables; - iptables_firewall = Some(Arc::new(Mutex::new(ipt_fw))); - log::info!("Using iptables firewall backend (forced)"); - } - Err(e) => { - log::error!("Failed to initialize iptables (forced): {}", e); - } - } - } else { - log::error!("iptables mode forced but iptables not available"); - } - } - FirewallMode::None => { - log::info!("Firewall disabled by config - userland enforcement only"); - } - FirewallMode::Auto => { - // Auto mode: try XDP > nftables > iptables > none - if !config.network.disable_xdp && !skels.is_empty() { - firewall_backend = FirewallBackend::Xdp; - let _ = access_rules::init_access_rules_from_global(&skels); - log::info!("Using XDP/BPF firewall backend"); - } else { - if config.network.disable_xdp { - log::info!("XDP disabled - trying fallback backends"); - } else { - log::warn!("XDP/BPF not available - trying fallback backends"); - } - - // Try nftables first - if NftablesFirewall::is_available() { - match NftablesFirewall::new() { - Ok(nft_fw) => { - firewall_backend = FirewallBackend::Nftables; - nftables_firewall = Some(Arc::new(Mutex::new(nft_fw))); - log::info!("Using nftables firewall backend"); - } - Err(e) => { - log::warn!("Failed to initialize nftables: {}", e); - } - } - } else { - log::warn!("nftables (nft) not available on system"); - } - - // If nftables failed, try iptables - if firewall_backend == FirewallBackend::None { - if IptablesFirewall::is_available() { - match IptablesFirewall::new() { - Ok(ipt_fw) => { - firewall_backend = FirewallBackend::Iptables; - iptables_firewall = Some(Arc::new(Mutex::new(ipt_fw))); - log::info!("Using iptables firewall backend"); - } - Err(e) => { - log::warn!("Failed to initialize iptables: {}", e); - } - } - } else { - log::warn!("iptables not available on system"); - } - } - - if firewall_backend == FirewallBackend::None { - log::warn!("No 
firewall backend available - access rules will be enforced in userland only"); - } - } - } - } - - // Create BPF statistics collector - let bpf_stats_collector = BpfStatsCollector::new(skels.clone(), config.bpf_stats.enabled); - - // Create TCP fingerprinting collector - let tcp_fingerprint_collector = TcpFingerprintCollector::new_with_config( - skels.clone(), - TcpFingerprintConfig::from_cli_config(&config.tcp_fingerprint) - ); - - // Set global TCP fingerprint collector for proxy access - crate::utils::tcp_fingerprint::set_global_tcp_fingerprint_collector(tcp_fingerprint_collector.clone()); - - // Initialize access rules for nftables or iptables backend if active - if firewall_backend == FirewallBackend::Nftables { - if let Some(ref nft_fw) = nftables_firewall { - if let Err(e) = access_rules::init_access_rules_nftables(nft_fw) { - log::error!("Failed to initialize nftables access rules: {}", e); - } - } - } else if firewall_backend == FirewallBackend::Iptables { - if let Some(ref ipt_fw) = iptables_firewall { - if let Err(e) = access_rules::init_access_rules_iptables(ipt_fw) { - log::error!("Failed to initialize iptables access rules: {}", e); - } - } - } - - let state = AppState { - skels: skels.clone(), - ifindices: ifindices.clone(), - bpf_stats_collector, - tcp_fingerprint_collector, - firewall_backend, - nftables_firewall: nftables_firewall.clone(), - iptables_firewall: iptables_firewall.clone(), - }; - - // Start the captcha verification server in a separate task (skip in agent mode to save memory) - let captcha_server_enabled = cfg!(feature = "proxy") && config.mode != "agent"; - if captcha_server_enabled { - #[cfg(feature = "proxy")] - { - tokio::spawn(async move { - if let Err(e) = captcha_server::start_captcha_server().await { - error!("Captcha server error: {}", e); - } - }); - } - } - - // Start embedded ACME server if enabled (skip in agent mode - no TLS termination needed) - #[cfg(feature = "proxy")] - if config.acme.enabled && config.mode != "agent" { - let acme_config = config.acme.clone(); - let pingora_config = config.pingora.clone(); - let redis_config = config.redis.clone(); - - tokio::spawn(async move { - use crate::acme::embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; - use std::path::PathBuf; - - // Use upstreams path from pingora configuration - let upstreams_path = PathBuf::from(&pingora_config.upstreams_conf); - - // Determine email - let email = acme_config.email - .unwrap_or_else(|| "admin@example.com".to_string()); - - // Determine Redis URL - let redis_url = acme_config.redis_url - .or_else(|| if redis_config.url.is_empty() { None } else { Some(redis_config.url) }); - - // Create Redis SSL config if available - let redis_ssl = redis_config.ssl.map(|ssl| crate::acme::config::RedisSslConfig { - ca_cert_path: ssl.ca_cert_path, - client_cert_path: ssl.client_cert_path, - client_key_path: ssl.client_key_path, - insecure: ssl.insecure, - }); - - // Log storage configuration for debugging - if let Some(ref st) = acme_config.storage_type { - info!("ACME storage_type from config: '{}'", st); - } else { - warn!("ACME storage_type not set in config, will auto-detect from redis_url"); - } - if let Some(ref ru) = redis_url { - info!("ACME redis_url: '{}'", ru); - } else { - warn!("ACME redis_url not set"); - } - - let embedded_acme_config = EmbeddedAcmeConfig { - port: acme_config.port, - bind_ip: "127.0.0.1".to_string(), - upstreams_path, - email, - storage_path: PathBuf::from(&acme_config.storage_path), - storage_type: acme_config.storage_type.clone(), - 
development: acme_config.development,
-                redis_url,
-                redis_ssl,
-            };
-
-            // Clone config for HTTP server before moving it
-            let http_server_config = embedded_acme_config.clone();
-
-            let acme_server = EmbeddedAcmeServer::new(embedded_acme_config);
-
-            // Initialize domain reader
-            if let Err(e) = acme_server.init_domain_reader().await {
-                error!("Failed to initialize ACME domain reader: {}", e);
-                return;
-            }
-
-            // Start the HTTP server first (in background) so endpoint checks can succeed
-            tokio::spawn(async move {
-                let http_server = EmbeddedAcmeServer::new(http_server_config);
-                // Initialize domain reader for the HTTP server
-                if let Err(e) = http_server.init_domain_reader().await {
-                    error!("Failed to initialize domain reader for HTTP server: {}", e);
-                    return;
-                }
-                if let Err(e) = http_server.start_server().await {
-                    error!("ACME server error: {}", e);
-                }
-            });
-
-            // Give the server a moment to start before processing certificates
-            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
-
-            // Process certificates initially (endpoint check will retry if server not ready)
-            if let Err(e) = acme_server.process_certificates().await {
-                warn!("Failed to process initial certificates: {}", e);
-            }
-        });
-    }
-    let (shutdown_tx, shutdown_rx) = watch::channel(false);
-
-    // Initialize Redis manager if Redis URL is provided (skip in agent mode)
-    let redis_initialized = if config.mode != "agent" && !config.redis.url.is_empty() {
-        match redis::RedisManager::init(&config.redis.url, config.redis.prefix.clone(), config.redis.ssl.as_ref()).await {
-            Ok(_) => true,
-            Err(e) => {
-                log::error!("Failed to initialize Redis manager: {}", e);
-                false
-            }
-        }
-    } else {
-        false
-    };
-
-    // Initialize worker manager
-    let (mut worker_manager, _worker_shutdown_rx) = worker::WorkerManager::new();
-
-    #[cfg(feature = "proxy")]
-    {
-        // Set ACME config for certificate worker to use (skip in agent mode)
-        if config.mode != "agent" {
-            worker::certificate::set_acme_config(config.acme.clone());
-        }
-
-        // Register certificate worker only if Redis was successfully initialized (skip in agent mode)
-        if redis_initialized && config.mode != "agent" {
-            // Parse proxy_certificates from config file (under pingora section)
-            let certificate_path = if let Some(config_path) = &args.config {
-                std::fs::read_to_string(config_path)
-                    .ok()
-                    .and_then(|content| serde_yaml::from_str::<serde_yaml::Value>(&content).ok())
-                    .and_then(|yaml| {
-                        // Try pingora.proxy_certificates first, then fallback to root level
-                        yaml.get("pingora")
-                            .and_then(|pingora| pingora.get("proxy_certificates"))
-                            .or_else(|| yaml.get("proxy_certificates"))
-                            .and_then(|v| v.as_str().map(|s| s.to_string()))
-                    })
-                    .unwrap_or_else(|| "/tmp/synapse-certs".to_string())
-            } else {
-                "/tmp/synapse-certs".to_string()
-            };
-
-            // Set proxy_certificates path for ACME certificate saving
-            crate::acme::set_proxy_certificates_path(Some(certificate_path.clone()));
-
-            let refresh_interval = 30; // 30 seconds default refresh interval
-            let worker_config = worker::WorkerConfig {
-                name: "certificate".to_string(),
-                interval_secs: refresh_interval,
-                enabled: true,
-            };
-
-            let upstreams_path = config.pingora.upstreams_conf.clone();
-            let certificate_worker = worker::certificate::CertificateWorker::new(
-                certificate_path.clone(),
-                upstreams_path,
-                refresh_interval
-            );
-
-            if let Err(e) = worker_manager.register_worker(worker_config, certificate_worker) {
-                log::error!("Failed to register certificate worker: {}", e);
-            }
-        }
-    }
-
-    // Validate API key if provided
-
if !config.platform.base_url.is_empty() && !config.platform.api_key.is_empty() { - if let Err(e) = validate_api_key( - &config.platform.base_url, - &config.platform.api_key, - ).await { - log::error!("API key validation failed: {}", e); - return Err(anyhow::anyhow!("API key validation failed: {}", e)); - } - } - - // Initialize content scanning from CLI config (skip in agent mode) - let content_scanner_enabled = config.mode != "agent" && config.content_scanning.enabled; - if content_scanner_enabled { - let content_scanning_config = ContentScanningConfig { - enabled: config.content_scanning.enabled, - clamav_server: config.content_scanning.clamav_server.clone(), - max_file_size: config.content_scanning.max_file_size, - scan_content_types: config.content_scanning.scan_content_types.clone(), - skip_extensions: config.content_scanning.skip_extensions.clone(), - scan_expression: config.content_scanning.scan_expression.clone(), - }; - if let Err(e) = init_content_scanner(content_scanning_config) { - log::warn!("Failed to initialize content scanner: {}", e); - } - } - - // Initialize access log sender configuration - let log_sender_config = LogSenderConfig { - enabled: config.platform.log_sending_enabled, - base_url: config.platform.base_url.clone(), - api_key: config.platform.api_key.clone(), - batch_size_limit: 5000, // Default: 5000 logs per batch - batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB - batch_timeout_secs: 10, // Default: 10 seconds - include_request_body: false, // Default: disabled - max_body_size: config.platform.max_body_size, - }; - set_log_sender_config(log_sender_config); - - // Register log sender worker if log sending is enabled - let log_sender_enabled = config.platform.log_sending_enabled && !config.platform.api_key.is_empty(); - if log_sender_enabled { - - let check_interval = 1; // Check every 1 second - let worker_config = worker::WorkerConfig { - name: "log_sender".to_string(), - interval_secs: check_interval, - enabled: true, - }; - - let log_sender_worker = worker::log::LogSenderWorker::new(check_interval); - - if let Err(e) = worker_manager.register_worker(worker_config, log_sender_worker) { - log::error!("Failed to register log sender worker: {}", e); - } - } - - // Determine if we have API key for full functionality - let has_api_key = !config.platform.api_key.is_empty(); - - - // Build list of interfaces to attach - if has_api_key && !config.platform.base_url.is_empty() { - // Skip WAF wirefilter initialization in agent mode (only access rules needed for XDP) - if !is_agent_mode { - if let Err(e) = init_config( - config.platform.base_url.clone(), - config.platform.api_key.clone(), - ) - .await - { - log::error!("Failed to initialize HTTP filter with config: {}", e); - log::error!("Aborting startup because WAF config could not be loaded"); - return Err(e); - } - waf_enabled = true; - } - - // Initialize threat intelligence client (skip in agent mode to save memory) - if !is_agent_mode { - threat_client_enabled = true; - let has_threat = !config.platform.threat.url.is_empty() || config.platform.threat.path.is_some(); - let has_geoip = !config.geoip.country.url.is_empty() - || !config.geoip.asn.url.is_empty() - || !config.geoip.city.url.is_empty() - || config.geoip.country.path.is_some() - || config.geoip.asn.path.is_some() - || config.geoip.city.path.is_some(); - - if has_threat || has_geoip { - if let Err(e) = threat::init_threat_client( - config.platform.threat.path.clone(), - config.geoip.country.path.clone(), - config.geoip.asn.path.clone(), - 
config.geoip.city.path.clone(), - ) - .await - { - log::warn!("Failed to initialize threat client: {}", e); - } else { - // Register Threat MMDB refresh worker if configured - let refresh_interval = config.platform.threat.refresh_secs.unwrap_or(300); - if !config.platform.threat.url.is_empty() && refresh_interval > 0 { - let worker_config = worker::WorkerConfig { - name: "threat_mmdb".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; - let worker = worker::threat_mmdb::ThreatMmdbWorker::new( - refresh_interval, - config.platform.threat.url.clone(), - config.platform.threat.path.clone(), - config.platform.threat.headers.clone(), - config.platform.api_key.clone(), - ); - if let Err(e) = worker_manager.register_worker(worker_config, worker) { - log::error!("Failed to register threat MMDB worker: {}", e); - } - } - - // Register GeoIP MMDB refresh workers if configured - let refresh_interval = config.geoip.refresh_secs; - - // Country database worker - if !config.geoip.country.url.is_empty() && refresh_interval > 0 { - let worker_config = worker::WorkerConfig { - name: "geoip_country_mmdb".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; - let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( - refresh_interval, - config.geoip.country.url.clone(), - "".to_string(), // versions_url not used for geoip - config.geoip.country.path.clone(), - config.geoip.country.headers.clone(), - worker::geoip_mmdb::GeoipDatabaseType::Country, - ); - if let Err(e) = worker_manager.register_worker(worker_config, worker) { - log::error!("Failed to register GeoIP Country MMDB worker: {}", e); - } - } - - // ASN database worker - if !config.geoip.asn.url.is_empty() && refresh_interval > 0 { - let worker_config = worker::WorkerConfig { - name: "geoip_asn_mmdb".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; - let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( - refresh_interval, - config.geoip.asn.url.clone(), - "".to_string(), - config.geoip.asn.path.clone(), - config.geoip.asn.headers.clone(), - worker::geoip_mmdb::GeoipDatabaseType::Asn, - ); - if let Err(e) = worker_manager.register_worker(worker_config, worker) { - log::error!("Failed to register GeoIP ASN MMDB worker: {}", e); - } - } - - // City database worker - if !config.geoip.city.url.is_empty() && refresh_interval > 0 { - let worker_config = worker::WorkerConfig { - name: "geoip_city_mmdb".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; - let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( - refresh_interval, - config.geoip.city.url.clone(), - "".to_string(), - config.geoip.city.path.clone(), - config.geoip.city.headers.clone(), - worker::geoip_mmdb::GeoipDatabaseType::City, - ); - if let Err(e) = worker_manager.register_worker(worker_config, worker) { - log::error!("Failed to register GeoIP City MMDB worker: {}", e); - } - } - } - } - } - - // Initialize captcha client if configuration is provided (skip in agent mode) - if let (Some(site_key), Some(secret_key), Some(jwt_secret)) = ( - &config.platform.captcha.site_key, - &config.platform.captcha.secret_key, - &config.platform.captcha.jwt_secret, - ) { - let captcha_config = CaptchaConfig { - site_key: site_key.clone(), - secret_key: secret_key.clone(), - jwt_secret: jwt_secret.clone(), - provider: CaptchaProvider::from_str(&config.platform.captcha.provider) - .unwrap_or(CaptchaProvider::HCaptcha), - token_ttl_seconds: config.platform.captcha.token_ttl, - validation_cache_ttl_seconds: 
config.platform.captcha.cache_ttl,
-            };
-
-            if let Err(e) = init_captcha_client(captcha_config).await {
-                log::warn!("Failed to initialize captcha client: {}", e);
-            } else {
-                captcha_client_enabled = true;
-                start_cache_cleanup_task().await;
-            }
-        }
-    } else {
-        // LOCAL MODE: Load security rules from local file
-        let security_rules_path = args.security_rules_config.clone();
-
-        // Load and initialize from local file
-        match worker::config::load_config_from_file(&security_rules_path).await {
-            Ok(config_response) => {
-                // Store config globally for access rules
-                worker::config::set_global_config(config_response.config.clone());
-
-                // Initialize WAF with local rules (skip in agent mode to save memory)
-                if !is_agent_mode {
-                    if let Err(e) = crate::waf::wirefilter::load_waf_rules(config_response.config.waf_rules.rules).await {
-                        log::warn!("Failed to load WAF rules from local file: {}", e);
-                    } else {
-                        waf_enabled = true;
-                    }
-                }
-
-                // Update access rules in XDP if available
-                if !state.skels.is_empty() {
-                    if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&state.skels, is_agent_mode) {
-                        log::warn!("Failed to apply access rules from local file to XDP: {}", e);
-                    }
-                }
-            }
-            Err(e) => {
-                log::error!("Failed to load security rules from local file: {}", e);
-            }
-        }
-    }
-
-    // Register agent status worker (register + heartbeat) if unified event sending is enabled
-    if log_sender_enabled {
-        let hostname = std::env::var("HOSTNAME")
-            .ok()
-            .filter(|value| !value.trim().is_empty())
-            .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned());
-
-        let agent_name = std::env::var("AGENT_NAME")
-            .ok()
-            .filter(|value| !value.trim().is_empty())
-            .unwrap_or_else(|| hostname.clone());
-
-        let workspace_id = config.platform.workspace_id.clone();
-        let agent_id = build_agent_id(&agent_name, &workspace_id);
-
-        let tags = std::env::var("AGENT_TAGS")
-            .ok()
-            .map(|value| {
-                value
-                    .split(',')
-                    .map(|tag| tag.trim().to_string())
-                    .filter(|tag| !tag.is_empty())
-                    .collect::<Vec<String>>()
-            })
-            .unwrap_or_default();
-
-        let mut capabilities = Vec::new();
-        if log_sender_enabled {
-            capabilities.push("log_sender".to_string());
-        }
-        if config.bpf_stats.enabled {
-            capabilities.push("bpf_stats".to_string());
-        }
-        if config.bpf_stats.enable_dropped_ip_events {
-            capabilities.push("bpf_stats_dropped_ip_events".to_string());
-        }
-        if config.tcp_fingerprint.enabled {
-            capabilities.push("tcp_fingerprint".to_string());
-        }
-        if config.tcp_fingerprint.enable_fingerprint_events {
-            capabilities.push("tcp_fingerprint_events".to_string());
-        }
-        if content_scanner_enabled {
-            capabilities.push("content_scanner".to_string());
-        }
-        if waf_enabled {
-            capabilities.push("waf".to_string());
-        }
-        if threat_client_enabled {
-            capabilities.push("threat_client".to_string());
-        }
-        if captcha_client_enabled {
-            capabilities.push("captcha_client".to_string());
-        }
-        if !config.network.disable_xdp {
-            capabilities.push("xdp".to_string());
-        }
-
-        let interfaces = if !config.network.ifaces.is_empty() {
-            config.network.ifaces.clone()
-        } else if !config.network.iface.is_empty() {
-            vec![config.network.iface.clone()]
-        } else {
-            Vec::new()
-        };
-
-        let ip_addresses = std::env::var("AGENT_IPS")
-            .or_else(|_| std::env::var("AGENT_IP_ADDRESSES"))
-            .ok()
-            .map(|value| {
-                value
-                    .split(',')
-                    .map(|ip| ip.trim().to_string())
-                    .filter(|ip| !ip.is_empty())
-                    .collect::<Vec<String>>()
-            })
-            .unwrap_or_default();
-
-        let mut metadata = HashMap::new();
-        metadata.insert("os".to_string(), std::env::consts::OS.to_string());
-        metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string());
-        metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
-        metadata.insert("mode".to_string(), config.mode.clone());
-        metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone());
-
-        let identity = AgentStatusIdentity {
-            agent_id,
-            agent_name,
-            hostname,
-            version: env!("CARGO_PKG_VERSION").to_string(),
-            mode: config.mode.clone(),
-            tags,
-            capabilities,
-            interfaces,
-            ip_addresses,
-            metadata,
-            started_at,
-        };
-
-        let heartbeat_secs = std::env::var("AGENT_HEARTBEAT_SECS")
-            .ok()
-            .and_then(|value| value.parse::<u64>().ok())
-            .unwrap_or(30);
-
-        let worker_config = worker::WorkerConfig {
-            name: "agent_status".to_string(),
-            interval_secs: heartbeat_secs,
-            enabled: true,
-        };
-
-        let agent_status_worker = AgentStatusWorker::new(identity, heartbeat_secs);
-        if let Err(e) = worker_manager.register_worker(worker_config, agent_status_worker) {
-            log::error!("Failed to register agent status worker: {}", e);
-        }
-    }
-
-    // Access rules were already initialized after XDP attachment above
-
-    // Register config worker to fetch and apply configuration periodically
-    if has_api_key && !config.platform.base_url.is_empty() {
-        let refresh_interval = 10; // 10 seconds config refresh interval
-        let worker_config = worker::WorkerConfig {
-            name: "config".to_string(),
-            interval_secs: refresh_interval,
-            enabled: true,
-        };
-
-        let config_worker = worker::config::ConfigWorker::new(
-            config.platform.base_url.clone(),
-            config.platform.api_key.clone(),
-            refresh_interval,
-            state.skels.clone(),
-            args.security_rules_config.clone(),
-        ).with_agent_mode(config.mode == "agent")
-         .with_nftables(state.nftables_firewall.clone())
-         .with_iptables(state.iptables_firewall.clone());
-
-        if let Err(e) = worker_manager.register_worker(worker_config, config_worker) {
-            log::error!("Failed to register config worker: {}", e);
-        }
-    } else {
-        // In local mode, register config worker that loads from file at startup only
-        let refresh_interval = 10;
-        let worker_config = worker::WorkerConfig {
-            name: "config".to_string(),
-            interval_secs: refresh_interval,
-            enabled: true,
-        };
-
-        let config_worker = worker::config::ConfigWorker::new(
-            String::new(),
-            String::new(),
-            refresh_interval,
-            state.skels.clone(),
-            args.security_rules_config.clone(),
-        ).with_agent_mode(config.mode == "agent")
-         .with_nftables(state.nftables_firewall.clone())
-         .with_iptables(state.iptables_firewall.clone());
-
-        if let Err(e) = worker_manager.register_worker(worker_config, config_worker) {
-            log::error!("Failed to register config worker: {}", e);
-        }
-    }
-
-    // Start the old Pingora proxy system in a separate thread (non-blocking)
-    // Only start if mode is "proxy" (disabled in agent mode)
-    #[cfg(feature = "proxy")]
-    if config.mode == "proxy" {
-        let bpf_stats_config = config.bpf_stats.clone();
-        let logging_config = config.logging.clone();
-        let platform_config = config.platform.clone();
-        let geoip_config = config.geoip.clone();
-        let network_config = config.network.clone();
-        let tcp_fingerprint_config = config.tcp_fingerprint.clone();
-        let pingora_config = config.pingora.clone();
-        std::thread::spawn(move || {
-            http_proxy::start::run_with_config(Some(crate::cli::Config {
-                mode: "proxy".to_string(),
-                worker_threads: None,
-                redis: Default::default(),
-                network: network_config,
-                platform: platform_config,
-                geoip: geoip_config,
-                content_scanning: Default::default(),
-                logging: logging_config,
-                bpf_stats: bpf_stats_config,
-                tcp_fingerprint: tcp_fingerprint_config,
-                daemon: Default::default(),
-                pingora: pingora_config,
-                acme: Default::default(),
-            }));
-        });
-    }
-
-    // Start BPF statistics logging task
-    let bpf_stats_handle = if config.bpf_stats.enabled && !state.skels.is_empty() {
-        let collector = state.bpf_stats_collector.clone();
-        let log_interval = config.bpf_stats.log_interval_secs;
-        let shutdown = shutdown_rx.clone();
-        Some(start_bpf_stats_logging(collector, log_interval, shutdown))
-    } else {
-        None
-    };
-
-    // Start dropped IP events logging task
-    let dropped_ip_events_handle = if config.bpf_stats.enabled &&
-        config.bpf_stats.enable_dropped_ip_events &&
-        !state.skels.is_empty() {
-        let collector = state.bpf_stats_collector.clone();
-        let log_interval = config.bpf_stats.dropped_ip_events_interval_secs;
-        let shutdown = shutdown_rx.clone();
-        Some(start_dropped_ip_events_logging(collector, log_interval, shutdown))
-    } else {
-        None
-    };
-
-    // Start TCP fingerprinting statistics logging task
-    let tcp_fingerprint_stats_handle = if config.tcp_fingerprint.enabled && !state.skels.is_empty() {
-        let collector = state.tcp_fingerprint_collector.clone();
-        let log_interval = config.tcp_fingerprint.log_interval_secs;
-        let shutdown = shutdown_rx.clone();
-        let state_clone = Arc::new(state.clone());
-        Some(start_tcp_fingerprint_stats_logging(collector, log_interval, shutdown, state_clone))
-    } else {
-        None
-    };
-
-    // Start TCP fingerprinting events logging task
-    let tcp_fingerprint_events_handle = if config.tcp_fingerprint.enabled &&
-        config.tcp_fingerprint.enable_fingerprint_events &&
-        !state.skels.is_empty() {
-        let collector = state.tcp_fingerprint_collector.clone();
-        let log_interval = config.tcp_fingerprint.fingerprint_events_interval_secs;
-        let shutdown = shutdown_rx.clone();
-        let state_clone = Arc::new(state.clone());
-        Some(start_tcp_fingerprint_events_logging(collector, log_interval, shutdown, state_clone))
-    } else {
-        None
-    };
-
-    // Log startup summary as JSON
-    let xdp_info: serde_json::Value = if xdp_modes.is_empty() {
-        serde_json::json!(null)
-    } else if xdp_modes.len() == 1 {
-        serde_json::json!(xdp_modes[0].1)
-    } else {
-        serde_json::json!(xdp_modes.iter().map(|(iface, mode)| {
-            serde_json::json!({ "interface": iface, "mode": mode })
-        }).collect::<Vec<_>>())
-    };
-
-    let worker_threads = match config.worker_threads {
-        Some(threads) => serde_json::json!(threads),
-        None if is_agent_mode => serde_json::json!(0),
-        None => serde_json::json!(null),
-    };
-
-    let startup_info = serde_json::json!({
-        "event": "startup_complete",
-        "mode": config.mode,
-        "runtime": if config.worker_threads == Some(0) || (config.worker_threads.is_none() && is_agent_mode) {
-            "single_thread"
-        } else {
-            "multi_thread"
-        },
-        "worker_threads": worker_threads,
-        "interfaces": &iface_names,
-        "xdp_enabled": !config.network.disable_xdp && !state.skels.is_empty(),
-        "xdp_mode": xdp_info,
-        "features": {
-            "bpf_filtering": !state.skels.is_empty(),
-            "bpf_stats": config.bpf_stats.enabled && !state.skels.is_empty(),
-            "tcp_fingerprint": config.tcp_fingerprint.enabled && !state.skels.is_empty(),
-            "waf": waf_enabled,
-            "threat_intel": threat_client_enabled,
-            "captcha_server": captcha_server_enabled,
-            "captcha_client": captcha_client_enabled,
-            "content_scanner": content_scanner_enabled,
-            "redis": redis_initialized,
-            "log_sender": log_sender_enabled,
-            "acme": cfg!(feature = "proxy") && config.acme.enabled && !is_agent_mode,
-            "proxy": cfg!(feature = "proxy") && config.mode != "agent",
"proxy") && config.mode != "agent", - }, - "api_configured": has_api_key, - }); - log::info!("{}", startup_info); - - signal::ctrl_c().await?; - log::info!("Shutdown signal received, stopping servers..."); - let _ = shutdown_tx.send(true); - - if let Some(handle) = bpf_stats_handle - && let Err(err) = handle.await - { - log::error!("bpf-stats task join error: {err}"); - } - - if let Some(handle) = dropped_ip_events_handle - && let Err(err) = handle.await - { - log::error!("dropped-ip-events task join error: {err}"); - } - - if let Some(handle) = tcp_fingerprint_stats_handle - && let Err(err) = handle.await - { - log::error!("tcp-fingerprint-stats task join error: {err}"); - } - - if let Some(handle) = tcp_fingerprint_events_handle - && let Err(err) = handle.await - { - log::error!("tcp-fingerprint-events task join error: {err}"); - } - - // Shutdown all workers - worker_manager.shutdown(); - worker_manager.wait_for_all().await; - log::info!("Proceeding with cleanup..."); - - // Detach XDP programs from interfaces - #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] - if !ifindices.is_empty() { - log::info!("Detaching XDP programs from {} interfaces...", ifindices.len()); - for ifindex in ifindices { - if let Err(e) = bpf_detach_from_xdp(ifindex) { - log::error!("Failed to detach XDP from interface {}: {}", ifindex, e); - } - } - } - - // Cleanup nftables rules if using nftables backend - if let Some(ref nft_fw) = nftables_firewall { - log::info!("Cleaning up nftables firewall rules..."); - if let Ok(fw) = nft_fw.lock() { - if let Err(e) = fw.cleanup() { - log::error!("Failed to cleanup nftables rules: {}", e); - } - } - } - - // Cleanup iptables rules if using iptables backend - if let Some(ref ipt_fw) = iptables_firewall { - log::info!("Cleaning up iptables firewall rules..."); - if let Ok(fw) = ipt_fw.lock() { - if let Err(e) = fw.cleanup() { - log::error!("Failed to cleanup iptables rules: {}", e); - } - } - } - - log::info!("Shutdown complete, exiting..."); - - // Force exit immediately - cleanup is done - std::process::exit(0); -} - -fn read_env_non_empty(name: &str) -> Option { - std::env::var(name) - .ok() - .map(|value| value.trim().to_string()) - .filter(|value| !value.is_empty()) -} - -fn build_agent_id(agent_name: &str, workspace_id: &str) -> String { - if read_env_non_empty("AGENT_ID").is_some() { - warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id."); - } - - if workspace_id.trim().is_empty() { - warn!( - "WORKSPACE_ID not set; agent_id derived only from agent_name '{}'. Set WORKSPACE_ID \ -(or ARXIGNIS_WORKSPACE_ID) to avoid collisions across organizations.", - agent_name - ); - } - - let input = format!("{}:{}", workspace_id.trim(), agent_name.trim()); - let digest = Sha256::digest(input.as_bytes()); - format!("{:x}", digest) -} - -/// Start a background task that logs BPF statistics periodically -fn start_bpf_stats_logging( - collector: BpfStatsCollector, - log_interval_secs: u64, - mut shutdown: tokio::sync::watch::Receiver, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); - - loop { - tokio::select! 
{ - _ = shutdown.changed() => { - if *shutdown.borrow() { break; } - } - _ = interval.tick() => { - if let Err(e) = collector.log_stats() { - log::warn!("Failed to log BPF statistics: {}", e); - } - } - } - } - - log::info!("BPF statistics logging task stopped"); - }) -} - -/// Start a background task that logs dropped IP events periodically -fn start_dropped_ip_events_logging( - collector: BpfStatsCollector, - log_interval_secs: u64, - mut shutdown: tokio::sync::watch::Receiver, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); - - loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { break; } - } - _ = interval.tick() => { - if let Err(e) = collector.log_dropped_ip_events() { - log::warn!("Failed to log dropped IP events: {}", e); - } - } - } - } - - log::info!("Dropped IP events logging task stopped"); - }) -} - -/// Start a background task that logs TCP fingerprinting statistics periodically -fn start_tcp_fingerprint_stats_logging( - collector: TcpFingerprintCollector, - log_interval_secs: u64, - mut shutdown: tokio::sync::watch::Receiver, - _state: Arc, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); - - loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { break; } - } - _ = interval.tick() => { - if let Err(e) = collector.log_stats() { - log::warn!("Failed to log TCP fingerprinting statistics: {}", e); - } - } - } - } - - log::info!("TCP fingerprinting statistics logging task stopped"); - }) -} - -/// Start a background task that logs TCP fingerprinting events periodically -fn start_tcp_fingerprint_events_logging( - collector: TcpFingerprintCollector, - log_interval_secs: u64, - mut shutdown: tokio::sync::watch::Receiver, - _state: Arc, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); - - loop { - tokio::select! 
{ - _ = shutdown.changed() => { - if *shutdown.borrow() { break; } - } - _ = interval.tick() => { - if let Err(e) = collector.log_fingerprint_events() { - log::warn!("Failed to log TCP fingerprinting events: {}", e); - } - } - } - } - - log::info!("TCP fingerprinting events logging task stopped"); - }) -} +use std::collections::HashMap; +use std::sync::Arc; +use std::str::FromStr; +use std::fs::File; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +use std::mem::MaybeUninit; + +use anyhow::{Context, Result}; +use clap::Parser; +use daemonize::Daemonize; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +use libbpf_rs::skel::{OpenSkel, SkelBuilder}; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +use nix::net::if_::if_nametoindex; +use chrono::Utc; +use sha2::{Digest, Sha256}; + +#[global_allocator] +static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; + +pub mod access_log; +pub mod access_rules; +pub mod agent_status; +pub mod app_state; +#[cfg(feature = "proxy")] +pub mod captcha_server; +pub mod cli; +pub mod content_scanning; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod firewall; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "firewall_noop.rs"] +pub mod firewall; +pub mod http_client; +pub mod waf; +pub mod threat; +pub mod redis; +#[cfg(feature = "proxy")] +pub mod proxy_protocol; +pub mod authcheck; +#[cfg(feature = "proxy")] +pub mod http_proxy; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod bpf { + // Include the skeleton generated by build.rs into OUT_DIR at compile time + include!(concat!(env!("OUT_DIR"), "/filter.skel.rs")); +} +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "bpf_stub.rs"] +pub mod bpf; + +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod bpf_stats; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "bpf_stats_noop.rs"] +pub mod bpf_stats; +pub mod ja4_plus; +pub mod utils; +pub mod worker; +#[cfg(feature = "proxy")] +pub mod acme; + +use tokio::signal; +use tokio::sync::watch; +#[cfg(feature = "proxy")] +use log::{error, info}; +use log::warn; + +use crate::app_state::AppState; +use crate::bpf_stats::BpfStatsCollector; +use crate::utils::tcp_fingerprint::TcpFingerprintCollector; +use crate::utils::tcp_fingerprint::TcpFingerprintConfig; +use crate::cli::{Args, Config}; +use crate::waf::wirefilter::init_config; +use crate::content_scanning::{init_content_scanner, ContentScanningConfig}; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +use crate::utils::bpf_utils::bpf_attach_to_xdp; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +use crate::utils::bpf_utils::bpf_detach_from_xdp; + +use crate::access_log::LogSenderConfig; +use crate::worker::log::set_log_sender_config; +use crate::worker::agent_status::AgentStatusWorker; +use crate::agent_status::AgentStatusIdentity; +use crate::authcheck::validate_api_key; +use crate::http_client::init_global_client; +use crate::waf::actions::captcha::{CaptchaConfig, CaptchaProvider, init_captcha_client, start_cache_cleanup_task}; + +fn main() -> Result<()> { + // Initialize rustls crypto provider early (must be done before any rustls operations) + rustls::crypto::ring::default_provider() + .install_default() + .map_err(|e| anyhow::anyhow!("Failed to install rustls crypto provider: {:?}", e))?; + + let args = Args::parse(); + // Handle clear certificate command (runs before loading full config) + if let Some(certificate_name) = 
&args.clear_certificate { + #[cfg(feature = "proxy")] + { + // Initialize minimal runtime for async operations + let rt = tokio::runtime::Runtime::new() + .context("Failed to create tokio runtime")?; + + // Load minimal config for Redis connection + let config = Config::load_from_args(&args) + .context("Failed to load configuration")?; + + // Initialize Redis if configured + if !config.redis.url.is_empty() { + rt.block_on(crate::redis::RedisManager::init( + &config.redis.url, + config.redis.prefix.clone(), + config.redis.ssl.as_ref(), + )) + .context("Failed to initialize Redis manager")?; + } + + // Get certificate path from config + let certificate_path = config + .pingora + .proxy_certificates + .clone() + .unwrap_or_else(|| "/etc/synapse/certs".to_string()); + + // Clear the certificate + rt.block_on(crate::worker::certificate::clear_certificate( + certificate_name, + &certificate_path, + ))?; + + return Ok(()); + } + + #[cfg(not(feature = "proxy"))] + { + let _ = certificate_name; + return Err(anyhow::anyhow!( + "clear-certificate is not available in agent-only builds" + )); + } + } + + // API key is optional - allow running in local mode without it + + // Load configuration + let config = Config::load_from_args(&args) + .context("Failed to load configuration")?; + + if config.mode == "proxy" && !cfg!(feature = "proxy") { + return Err(anyhow::anyhow!( + "proxy mode is not supported in agent-only builds (build with the `proxy` feature)" + )); + } + + // Handle daemonization before starting tokio runtime + if config.daemon.enabled { + let stdout = File::create(&config.daemon.stdout) + .with_context(|| format!("Failed to create stdout file: {}", config.daemon.stdout))?; + let stderr = File::create(&config.daemon.stderr) + .with_context(|| format!("Failed to create stderr file: {}", config.daemon.stderr))?; + + let mut daemonize = Daemonize::new() + .pid_file(&config.daemon.pid_file) + .chown_pid_file(config.daemon.chown_pid_file) + .working_directory(&config.daemon.working_directory) + .stdout(stdout) + .stderr(stderr); + + if let Some(user) = &config.daemon.user { + daemonize = daemonize.user(user.as_str()); + } + + if let Some(group) = &config.daemon.group { + daemonize = daemonize.group(group.as_str()); + } + + match daemonize.start() { + Ok(_) => { + // We're now in the daemon process, continue with application startup + } + Err(e) => { + eprintln!("Failed to daemonize: {}", e); + return Err(anyhow::anyhow!("Daemonization failed: {}", e)); + } + } + } + + // Set RUST_LOG environment variable from config so other modules can use it + let log_level = if !config.logging.level.is_empty() { + config.logging.level.to_lowercase() + } else { + match args.log_level { + crate::cli::LogLevel::Error => "error", + crate::cli::LogLevel::Warn => "warn", + crate::cli::LogLevel::Info => "info", + crate::cli::LogLevel::Debug => "debug", + crate::cli::LogLevel::Trace => "trace", + }.to_string() + }; + unsafe { + std::env::set_var("RUST_LOG", &log_level); + } + + // Initialize logger using config level (CLI overrides if provided explicitly) + // Note: env_logger writes to stderr by default, which is standard practice + { + use env_logger::Env; + let mut builder = env_logger::Builder::from_env(Env::default().default_filter_or("info")); + + // Use log level from config, or CLI if explicitly set + let level_filter = match log_level.as_str() { + "error" => log::LevelFilter::Error, + "warn" => log::LevelFilter::Warn, + "info" => log::LevelFilter::Info, + "debug" => log::LevelFilter::Debug, + "trace" => 
log::LevelFilter::Trace,
+            _ => args.log_level.to_level_filter(),
+        };
+        builder.filter_level(level_filter);
+        builder.format_timestamp_secs();
+
+        // In daemon mode, write to stdout instead of stderr for better log separation
+        if config.daemon.enabled {
+            builder.target(env_logger::Target::Stdout);
+        }
+
+        builder.try_init().ok();
+    }
+
+    // Start the tokio runtime and run the async application
+    let runtime = match config.worker_threads {
+        // Explicit config: use specified thread count (0 = single-threaded)
+        Some(0) => {
+            tokio::runtime::Builder::new_current_thread()
+                .enable_all()
+                .build()?
+        }
+        Some(threads) => {
+            tokio::runtime::Builder::new_multi_thread()
+                .worker_threads(threads)
+                .enable_all()
+                .build()?
+        }
+        None if config.mode == "agent" => {
+            tokio::runtime::Builder::new_current_thread()
+                .enable_all()
+                .build()?
+        }
+        None => {
+            tokio::runtime::Builder::new_multi_thread()
+                .enable_all()
+                .build()?
+        }
+    };
+    runtime.block_on(async_main(args, config))
+}
+
+#[allow(clippy::too_many_lines)]
+async fn async_main(args: Args, config: Config) -> Result<()> {
+    let started_at = Utc::now();
+
+    if config.daemon.enabled {
+        log::info!("Running in daemon mode (PID file: {})", config.daemon.pid_file);
+    }
+
+    // Initialize global HTTP client with keepalive configuration
+    if let Err(e) = init_global_client() {
+        log::warn!("Failed to initialize global HTTP client: {}", e);
+    }
+
+    // Track enabled features for startup summary
+    let is_agent_mode = config.mode == "agent";
+    let mut waf_enabled = false;
+    let mut threat_client_enabled = false;
+    let mut captcha_client_enabled = false;
+
+    let iface_names: Vec<String> = if !config.network.ifaces.is_empty() {
+        config.network.ifaces.clone()
+    } else {
+        vec![config.network.iface.clone()]
+    };
+
+    use crate::firewall::{FirewallBackend, FirewallMode, NftablesFirewall, IptablesFirewall};
+    use std::sync::Mutex;
+
+    #[allow(unused_mut)]
+    let mut skels: Vec<Arc<bpf::FilterSkel<'static>>> = Vec::new();
+    #[allow(unused_mut)]
+    let mut ifindices: Vec<i32> = Vec::new();
+    let mut firewall_backend = FirewallBackend::None;
+    let mut nftables_firewall: Option<Arc<Mutex<NftablesFirewall>>> = None;
+    let mut iptables_firewall: Option<Arc<Mutex<IptablesFirewall>>> = None;
+    let firewall_mode = config.network.firewall_mode;
+
+    // Track XDP modes per interface for startup summary
+    #[allow(unused_mut)]
+    let mut xdp_modes: Vec<(&str, &str)> = Vec::new();
+
+    if config.network.disable_xdp {
+        log::warn!("XDP disabled by config, will use nftables fallback");
+    } else {
+        #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))]
+        {
+            // Suppress libbpf output - errors will be handled gracefully with fallback to nftables/iptables
+            libbpf_rs::set_print(None);
+
+            for iface in &iface_names {
+                let boxed_open: Box<MaybeUninit<libbpf_rs::OpenObject>> = Box::new(MaybeUninit::uninit());
+                let open_object: &'static mut MaybeUninit<libbpf_rs::OpenObject> = Box::leak(boxed_open);
+                let skel_builder = bpf::FilterSkelBuilder::default();
+                match skel_builder.open(open_object).and_then(|o| o.load()) {
+                    Ok(mut skel) => {
+                        let ifindex = match if_nametoindex(iface.as_str()) {
+                            Ok(index) => index as i32,
+                            Err(e) => {
+                                log::error!("failed to get interface index for '{}': {e}", iface);
+                                continue;
+                            }
+                        };
+                        match bpf_attach_to_xdp(&mut skel, ifindex, Some(iface.as_str()), &config.network.ip_version) {
+                            Ok(mode) => {
+                                xdp_modes.push((iface.as_str(), mode.as_str()));
+                                skels.push(Arc::new(skel));
+                                ifindices.push(ifindex);
+                            }
+                            Err(e) => {
+                                // Check if error is EAFNOSUPPORT (error 97) - IPv6 might be disabled
+                                let error_str = e.to_string();
+                                if error_str.contains("97") ||
error_str.contains("Address family not supported") { + log::warn!("Failed to attach XDP to '{}': {} (IPv6 disabled)", iface, e); + } else { + log::error!("Failed to attach XDP to '{}': {}", iface, e); + } + } + } + } + Err(e) => { + // Check for common BPF/kernel support errors + let error_str = e.to_string(); + if error_str.contains("Function not implemented") || + error_str.contains("ENOSYS") || + error_str.contains("error 38") || + error_str.contains("CONFIG_BPF_SYSCALL") { + log::warn!("BPF not supported by kernel for '{}': {}", iface, e); + } else { + log::warn!("failed to load BPF skeleton for '{}': {e}", iface); + } + } + } + } + } + + #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] + { + log::warn!("BPF support disabled at build time; skipping XDP attachment"); + log::warn!("Set 'disable_xdp: true' in config.yaml to suppress this warning"); + } + } + + // Determine firewall backend based on config and availability + log::info!("Firewall mode configured: {}", firewall_mode); + + match firewall_mode { + FirewallMode::Xdp => { + // Forced XDP mode + if !skels.is_empty() { + firewall_backend = FirewallBackend::Xdp; + let _ = access_rules::init_access_rules_from_global(&skels); + log::info!("Using XDP/BPF firewall backend (forced)"); + } else { + log::error!("XDP mode forced but BPF not available"); + } + } + FirewallMode::Nftables => { + // Forced nftables mode + if NftablesFirewall::is_available() { + match NftablesFirewall::new() { + Ok(nft_fw) => { + firewall_backend = FirewallBackend::Nftables; + nftables_firewall = Some(Arc::new(Mutex::new(nft_fw))); + log::info!("Using nftables firewall backend (forced)"); + } + Err(e) => { + log::error!("Failed to initialize nftables (forced): {}", e); + } + } + } else { + log::error!("nftables mode forced but nft command not available"); + } + } + FirewallMode::Iptables => { + // Forced iptables mode + if IptablesFirewall::is_available() { + match IptablesFirewall::new() { + Ok(ipt_fw) => { + firewall_backend = FirewallBackend::Iptables; + iptables_firewall = Some(Arc::new(Mutex::new(ipt_fw))); + log::info!("Using iptables firewall backend (forced)"); + } + Err(e) => { + log::error!("Failed to initialize iptables (forced): {}", e); + } + } + } else { + log::error!("iptables mode forced but iptables not available"); + } + } + FirewallMode::None => { + log::info!("Firewall disabled by config - userland enforcement only"); + } + FirewallMode::Auto => { + // Auto mode: try XDP > nftables > iptables > none + if !config.network.disable_xdp && !skels.is_empty() { + firewall_backend = FirewallBackend::Xdp; + let _ = access_rules::init_access_rules_from_global(&skels); + log::info!("Using XDP/BPF firewall backend"); + } else { + if config.network.disable_xdp { + log::info!("XDP disabled - trying fallback backends"); + } else { + log::warn!("XDP/BPF not available - trying fallback backends"); + } + + // Try nftables first + if NftablesFirewall::is_available() { + match NftablesFirewall::new() { + Ok(nft_fw) => { + firewall_backend = FirewallBackend::Nftables; + nftables_firewall = Some(Arc::new(Mutex::new(nft_fw))); + log::info!("Using nftables firewall backend"); + } + Err(e) => { + log::warn!("Failed to initialize nftables: {}", e); + } + } + } else { + log::warn!("nftables (nft) not available on system"); + } + + // If nftables failed, try iptables + if firewall_backend == FirewallBackend::None { + if IptablesFirewall::is_available() { + match IptablesFirewall::new() { + Ok(ipt_fw) => { + firewall_backend = FirewallBackend::Iptables; + 
iptables_firewall = Some(Arc::new(Mutex::new(ipt_fw))); + log::info!("Using iptables firewall backend"); + } + Err(e) => { + log::warn!("Failed to initialize iptables: {}", e); + } + } + } else { + log::warn!("iptables not available on system"); + } + } + + if firewall_backend == FirewallBackend::None { + log::warn!("No firewall backend available - access rules will be enforced in userland only"); + } + } + } + } + + // Create BPF statistics collector + let bpf_stats_collector = BpfStatsCollector::new(skels.clone(), config.bpf_stats.enabled); + + // Create TCP fingerprinting collector + let tcp_fingerprint_collector = TcpFingerprintCollector::new_with_config( + skels.clone(), + TcpFingerprintConfig::from_cli_config(&config.tcp_fingerprint) + ); + + // Set global TCP fingerprint collector for proxy access + crate::utils::tcp_fingerprint::set_global_tcp_fingerprint_collector(tcp_fingerprint_collector.clone()); + + // Initialize access rules for nftables or iptables backend if active + if firewall_backend == FirewallBackend::Nftables { + if let Some(ref nft_fw) = nftables_firewall { + if let Err(e) = access_rules::init_access_rules_nftables(nft_fw) { + log::error!("Failed to initialize nftables access rules: {}", e); + } + } + } else if firewall_backend == FirewallBackend::Iptables { + if let Some(ref ipt_fw) = iptables_firewall { + if let Err(e) = access_rules::init_access_rules_iptables(ipt_fw) { + log::error!("Failed to initialize iptables access rules: {}", e); + } + } + } + + let state = AppState { + skels: skels.clone(), + ifindices: ifindices.clone(), + bpf_stats_collector, + tcp_fingerprint_collector, + firewall_backend, + nftables_firewall: nftables_firewall.clone(), + iptables_firewall: iptables_firewall.clone(), + }; + + // Start the captcha verification server in a separate task (skip in agent mode to save memory) + let captcha_server_enabled = cfg!(feature = "proxy") && config.mode != "agent"; + if captcha_server_enabled { + #[cfg(feature = "proxy")] + { + tokio::spawn(async move { + if let Err(e) = captcha_server::start_captcha_server().await { + error!("Captcha server error: {}", e); + } + }); + } + } + + // Start embedded ACME server if enabled (skip in agent mode - no TLS termination needed) + #[cfg(feature = "proxy")] + if config.acme.enabled && config.mode != "agent" { + let acme_config = config.acme.clone(); + let pingora_config = config.pingora.clone(); + let redis_config = config.redis.clone(); + + tokio::spawn(async move { + use crate::acme::embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; + use std::path::PathBuf; + + // Use upstreams path from pingora configuration + let upstreams_path = PathBuf::from(&pingora_config.upstreams_conf); + + // Determine email + let email = acme_config.email + .unwrap_or_else(|| "admin@example.com".to_string()); + + // Determine Redis URL + let redis_url = acme_config.redis_url + .or_else(|| if redis_config.url.is_empty() { None } else { Some(redis_config.url) }); + + // Create Redis SSL config if available + let redis_ssl = redis_config.ssl.map(|ssl| crate::acme::config::RedisSslConfig { + ca_cert_path: ssl.ca_cert_path, + client_cert_path: ssl.client_cert_path, + client_key_path: ssl.client_key_path, + insecure: ssl.insecure, + }); + + // Log storage configuration for debugging + if let Some(ref st) = acme_config.storage_type { + info!("ACME storage_type from config: '{}'", st); + } else { + warn!("ACME storage_type not set in config, will auto-detect from redis_url"); + } + if let Some(ref ru) = redis_url { + info!("ACME redis_url: 
'{}'", ru);
+            } else {
+                warn!("ACME redis_url not set");
+            }
+
+            let embedded_acme_config = EmbeddedAcmeConfig {
+                port: acme_config.port,
+                bind_ip: "127.0.0.1".to_string(),
+                upstreams_path,
+                email,
+                storage_path: PathBuf::from(&acme_config.storage_path),
+                storage_type: acme_config.storage_type.clone(),
+                development: acme_config.development,
+                redis_url,
+                redis_ssl,
+            };
+
+            // Clone config for HTTP server before moving it
+            let http_server_config = embedded_acme_config.clone();
+
+            let acme_server = EmbeddedAcmeServer::new(embedded_acme_config);
+
+            // Initialize domain reader
+            if let Err(e) = acme_server.init_domain_reader().await {
+                error!("Failed to initialize ACME domain reader: {}", e);
+                return;
+            }
+
+            // Start the HTTP server first (in background) so endpoint checks can succeed
+            tokio::spawn(async move {
+                let http_server = EmbeddedAcmeServer::new(http_server_config);
+                // Initialize domain reader for the HTTP server
+                if let Err(e) = http_server.init_domain_reader().await {
+                    error!("Failed to initialize domain reader for HTTP server: {}", e);
+                    return;
+                }
+                if let Err(e) = http_server.start_server().await {
+                    error!("ACME server error: {}", e);
+                }
+            });
+
+            // Give the server a moment to start before processing certificates
+            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
+
+            // Process certificates initially (endpoint check will retry if server not ready)
+            if let Err(e) = acme_server.process_certificates().await {
+                warn!("Failed to process initial certificates: {}", e);
+            }
+        });
+    }
+    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+
+    // Initialize Redis manager if Redis URL is provided (skip in agent mode)
+    let redis_initialized = if config.mode != "agent" && !config.redis.url.is_empty() {
+        match redis::RedisManager::init(&config.redis.url, config.redis.prefix.clone(), config.redis.ssl.as_ref()).await {
+            Ok(_) => true,
+            Err(e) => {
+                log::error!("Failed to initialize Redis manager: {}", e);
+                false
+            }
+        }
+    } else {
+        false
+    };
+
+    // Initialize worker manager
+    let (mut worker_manager, _worker_shutdown_rx) = worker::WorkerManager::new();
+
+    #[cfg(feature = "proxy")]
+    {
+        // Set ACME config for certificate worker to use (skip in agent mode)
+        if config.mode != "agent" {
+            worker::certificate::set_acme_config(config.acme.clone());
+        }
+
+        // Register certificate worker only if Redis was successfully initialized (skip in agent mode)
+        if redis_initialized && config.mode != "agent" {
+            // Parse proxy_certificates from config file (under pingora section)
+            let certificate_path = if let Some(config_path) = &args.config {
+                std::fs::read_to_string(config_path)
+                    .ok()
+                    .and_then(|content| serde_yaml::from_str::<serde_yaml::Value>(&content).ok())
+                    .and_then(|yaml| {
+                        // Try pingora.proxy_certificates first, then fallback to root level
+                        yaml.get("pingora")
+                            .and_then(|pingora| pingora.get("proxy_certificates"))
+                            .or_else(|| yaml.get("proxy_certificates"))
+                            .and_then(|v| v.as_str().map(|s| s.to_string()))
+                    })
+                    .unwrap_or_else(|| "/tmp/synapse-certs".to_string())
+            } else {
+                "/tmp/synapse-certs".to_string()
+            };
+
+            // Set proxy_certificates path for ACME certificate saving
+            crate::acme::set_proxy_certificates_path(Some(certificate_path.clone()));
+
+            let refresh_interval = 30; // 30 seconds default refresh interval
+            let worker_config = worker::WorkerConfig {
+                name: "certificate".to_string(),
+                interval_secs: refresh_interval,
+                enabled: true,
+            };
+
+            let upstreams_path = config.pingora.upstreams_conf.clone();
+            let certificate_worker =
worker::certificate::CertificateWorker::new( + certificate_path.clone(), + upstreams_path, + refresh_interval + ); + + if let Err(e) = worker_manager.register_worker(worker_config, certificate_worker) { + log::error!("Failed to register certificate worker: {}", e); + } + } + } + + // Validate API key if provided + if !config.platform.base_url.is_empty() && !config.platform.api_key.is_empty() { + if let Err(e) = validate_api_key( + &config.platform.base_url, + &config.platform.api_key, + ).await { + log::error!("API key validation failed: {}", e); + return Err(anyhow::anyhow!("API key validation failed: {}", e)); + } + } + + // Initialize content scanning from CLI config (skip in agent mode) + let content_scanner_enabled = config.mode != "agent" && config.content_scanning.enabled; + if content_scanner_enabled { + let content_scanning_config = ContentScanningConfig { + enabled: config.content_scanning.enabled, + clamav_server: config.content_scanning.clamav_server.clone(), + max_file_size: config.content_scanning.max_file_size, + scan_content_types: config.content_scanning.scan_content_types.clone(), + skip_extensions: config.content_scanning.skip_extensions.clone(), + scan_expression: config.content_scanning.scan_expression.clone(), + }; + if let Err(e) = init_content_scanner(content_scanning_config) { + log::warn!("Failed to initialize content scanner: {}", e); + } + } + + // Initialize access log sender configuration + let log_sender_config = LogSenderConfig { + enabled: config.platform.log_sending_enabled, + base_url: config.platform.base_url.clone(), + api_key: config.platform.api_key.clone(), + batch_size_limit: 5000, // Default: 5000 logs per batch + batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB + batch_timeout_secs: 10, // Default: 10 seconds + include_request_body: false, // Default: disabled + max_body_size: config.platform.max_body_size, + }; + set_log_sender_config(log_sender_config); + + // Register log sender worker if log sending is enabled + let log_sender_enabled = config.platform.log_sending_enabled && !config.platform.api_key.is_empty(); + if log_sender_enabled { + + let check_interval = 1; // Check every 1 second + let worker_config = worker::WorkerConfig { + name: "log_sender".to_string(), + interval_secs: check_interval, + enabled: true, + }; + + let log_sender_worker = worker::log::LogSenderWorker::new(check_interval); + + if let Err(e) = worker_manager.register_worker(worker_config, log_sender_worker) { + log::error!("Failed to register log sender worker: {}", e); + } + } + + // Determine if we have API key for full functionality + let has_api_key = !config.platform.api_key.is_empty(); + + + // Build list of interfaces to attach + if has_api_key && !config.platform.base_url.is_empty() { + // Skip WAF wirefilter initialization in agent mode (only access rules needed for XDP) + if !is_agent_mode { + if let Err(e) = init_config( + config.platform.base_url.clone(), + config.platform.api_key.clone(), + ) + .await + { + log::error!("Failed to initialize HTTP filter with config: {}", e); + log::error!("Aborting startup because WAF config could not be loaded"); + return Err(e); + } + waf_enabled = true; + } + + // Initialize threat intelligence client (skip in agent mode to save memory) + if !is_agent_mode { + threat_client_enabled = true; + let has_threat = !config.platform.threat.url.is_empty() || config.platform.threat.path.is_some(); + let has_geoip = !config.geoip.country.url.is_empty() + || !config.geoip.asn.url.is_empty() + || !config.geoip.city.url.is_empty() + 
|| config.geoip.country.path.is_some() + || config.geoip.asn.path.is_some() + || config.geoip.city.path.is_some(); + + if has_threat || has_geoip { + if let Err(e) = threat::init_threat_client( + config.platform.threat.path.clone(), + config.geoip.country.path.clone(), + config.geoip.asn.path.clone(), + config.geoip.city.path.clone(), + ) + .await + { + log::warn!("Failed to initialize threat client: {}", e); + } else { + // Register Threat MMDB refresh worker if configured + let refresh_interval = config.platform.threat.refresh_secs.unwrap_or(300); + if !config.platform.threat.url.is_empty() && refresh_interval > 0 { + let worker_config = worker::WorkerConfig { + name: "threat_mmdb".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; + let worker = worker::threat_mmdb::ThreatMmdbWorker::new( + refresh_interval, + config.platform.threat.url.clone(), + config.platform.threat.path.clone(), + config.platform.threat.headers.clone(), + config.platform.api_key.clone(), + ); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register threat MMDB worker: {}", e); + } + } + + // Register GeoIP MMDB refresh workers if configured + let refresh_interval = config.geoip.refresh_secs; + + // Country database worker + if !config.geoip.country.url.is_empty() && refresh_interval > 0 { + let worker_config = worker::WorkerConfig { + name: "geoip_country_mmdb".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; + let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( + refresh_interval, + config.geoip.country.url.clone(), + "".to_string(), // versions_url not used for geoip + config.geoip.country.path.clone(), + config.geoip.country.headers.clone(), + worker::geoip_mmdb::GeoipDatabaseType::Country, + ); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register GeoIP Country MMDB worker: {}", e); + } + } + + // ASN database worker + if !config.geoip.asn.url.is_empty() && refresh_interval > 0 { + let worker_config = worker::WorkerConfig { + name: "geoip_asn_mmdb".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; + let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( + refresh_interval, + config.geoip.asn.url.clone(), + "".to_string(), + config.geoip.asn.path.clone(), + config.geoip.asn.headers.clone(), + worker::geoip_mmdb::GeoipDatabaseType::Asn, + ); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register GeoIP ASN MMDB worker: {}", e); + } + } + + // City database worker + if !config.geoip.city.url.is_empty() && refresh_interval > 0 { + let worker_config = worker::WorkerConfig { + name: "geoip_city_mmdb".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; + let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( + refresh_interval, + config.geoip.city.url.clone(), + "".to_string(), + config.geoip.city.path.clone(), + config.geoip.city.headers.clone(), + worker::geoip_mmdb::GeoipDatabaseType::City, + ); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register GeoIP City MMDB worker: {}", e); + } + } + } + } + } + + // Initialize captcha client if configuration is provided (skip in agent mode) + if let (Some(site_key), Some(secret_key), Some(jwt_secret)) = ( + &config.platform.captcha.site_key, + &config.platform.captcha.secret_key, + &config.platform.captcha.jwt_secret, + ) { + let captcha_config = CaptchaConfig { + site_key: 
site_key.clone(),
+                secret_key: secret_key.clone(),
+                jwt_secret: jwt_secret.clone(),
+                provider: CaptchaProvider::from_str(&config.platform.captcha.provider)
+                    .unwrap_or(CaptchaProvider::HCaptcha),
+                token_ttl_seconds: config.platform.captcha.token_ttl,
+                validation_cache_ttl_seconds: config.platform.captcha.cache_ttl,
+            };
+
+            if let Err(e) = init_captcha_client(captcha_config).await {
+                log::warn!("Failed to initialize captcha client: {}", e);
+            } else {
+                captcha_client_enabled = true;
+                start_cache_cleanup_task().await;
+            }
+        }
+    } else {
+        // LOCAL MODE: Load security rules from local file
+        let security_rules_path = args.security_rules_config.clone();
+
+        // Load and initialize from local file
+        match worker::config::load_config_from_file(&security_rules_path).await {
+            Ok(config_response) => {
+                // Store config globally for access rules
+                worker::config::set_global_config(config_response.config.clone());
+
+                // Initialize WAF with local rules (skip in agent mode to save memory)
+                if !is_agent_mode {
+                    if let Err(e) = crate::waf::wirefilter::load_waf_rules(config_response.config.waf_rules.rules).await {
+                        log::warn!("Failed to load WAF rules from local file: {}", e);
+                    } else {
+                        waf_enabled = true;
+                    }
+                }
+
+                // Update access rules in XDP if available
+                if !state.skels.is_empty() {
+                    if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&state.skels, is_agent_mode) {
+                        log::warn!("Failed to apply access rules from local file to XDP: {}", e);
+                    }
+                }
+            }
+            Err(e) => {
+                log::error!("Failed to load security rules from local file: {}", e);
+            }
+        }
+    }
+
+    // Register agent status worker (register + heartbeat) if unified event sending is enabled
+    if log_sender_enabled {
+        let hostname = std::env::var("HOSTNAME")
+            .ok()
+            .filter(|value| !value.trim().is_empty())
+            .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned());
+
+        let agent_name = std::env::var("AGENT_NAME")
+            .ok()
+            .filter(|value| !value.trim().is_empty())
+            .unwrap_or_else(|| hostname.clone());
+
+        let workspace_id = config.platform.workspace_id.clone();
+        let agent_id = build_agent_id(&agent_name, &workspace_id);
+
+        let tags = std::env::var("AGENT_TAGS")
+            .ok()
+            .map(|value| {
+                value
+                    .split(',')
+                    .map(|tag| tag.trim().to_string())
+                    .filter(|tag| !tag.is_empty())
+                    .collect::<Vec<String>>()
+            })
+            .unwrap_or_default();
+
+        let mut capabilities = Vec::new();
+        if log_sender_enabled {
+            capabilities.push("log_sender".to_string());
+        }
+        if config.bpf_stats.enabled {
+            capabilities.push("bpf_stats".to_string());
+        }
+        if config.bpf_stats.enable_dropped_ip_events {
+            capabilities.push("bpf_stats_dropped_ip_events".to_string());
+        }
+        if config.tcp_fingerprint.enabled {
+            capabilities.push("tcp_fingerprint".to_string());
+        }
+        if config.tcp_fingerprint.enable_fingerprint_events {
+            capabilities.push("tcp_fingerprint_events".to_string());
+        }
+        if content_scanner_enabled {
+            capabilities.push("content_scanner".to_string());
+        }
+        if waf_enabled {
+            capabilities.push("waf".to_string());
+        }
+        if threat_client_enabled {
+            capabilities.push("threat_client".to_string());
+        }
+        if captcha_client_enabled {
+            capabilities.push("captcha_client".to_string());
+        }
+        if !config.network.disable_xdp {
+            capabilities.push("xdp".to_string());
+        }
+
+        let interfaces = if !config.network.ifaces.is_empty() {
+            config.network.ifaces.clone()
+        } else if !config.network.iface.is_empty() {
+            vec![config.network.iface.clone()]
+        } else {
+            Vec::new()
+        };
+
+        let ip_addresses = std::env::var("AGENT_IPS")
.or_else(|_| std::env::var("AGENT_IP_ADDRESSES"))
+            .ok()
+            .map(|value| {
+                value
+                    .split(',')
+                    .map(|ip| ip.trim().to_string())
+                    .filter(|ip| !ip.is_empty())
+                    .collect::<Vec<String>>()
+            })
+            .unwrap_or_default();
+
+        let mut metadata = HashMap::new();
+        metadata.insert("os".to_string(), std::env::consts::OS.to_string());
+        metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string());
+        metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
+        metadata.insert("mode".to_string(), config.mode.clone());
+        metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone());
+
+        let identity = AgentStatusIdentity {
+            agent_id,
+            agent_name,
+            hostname,
+            version: env!("CARGO_PKG_VERSION").to_string(),
+            mode: config.mode.clone(),
+            tags,
+            capabilities,
+            interfaces,
+            ip_addresses,
+            metadata,
+            started_at,
+        };
+
+        let heartbeat_secs = std::env::var("AGENT_HEARTBEAT_SECS")
+            .ok()
+            .and_then(|value| value.parse::<u64>().ok())
+            .unwrap_or(30);
+
+        let worker_config = worker::WorkerConfig {
+            name: "agent_status".to_string(),
+            interval_secs: heartbeat_secs,
+            enabled: true,
+        };
+
+        let agent_status_worker = AgentStatusWorker::new(identity, heartbeat_secs);
+        if let Err(e) = worker_manager.register_worker(worker_config, agent_status_worker) {
+            log::error!("Failed to register agent status worker: {}", e);
+        }
+    }
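The env-driven identity fields above are parsed leniently; AGENT_TAGS and AGENT_IPS share the same split/trim/filter chain. A minimal standalone sketch of that chain with illustrative values (set_var is wrapped in unsafe to match this file's convention):

fn main() {
    // Same shape as the AGENT_TAGS / AGENT_IPS handling above.
    unsafe { std::env::set_var("AGENT_TAGS", "prod, eu-west ,,"); }
    let tags: Vec<String> = std::env::var("AGENT_TAGS")
        .ok()
        .map(|value| {
            value
                .split(',')
                .map(|tag| tag.trim().to_string())
                .filter(|tag| !tag.is_empty())
                .collect()
        })
        .unwrap_or_default();
    // Whitespace is trimmed and empty entries are dropped.
    assert_eq!(tags, vec!["prod", "eu-west"]);
}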
+
+    // Access rules were already initialized after XDP attachment above
+
+    // Register config worker to fetch and apply configuration periodically
+    if has_api_key && !config.platform.base_url.is_empty() {
+        let refresh_interval = 10; // 10 seconds config refresh interval
+        let worker_config = worker::WorkerConfig {
+            name: "config".to_string(),
+            interval_secs: refresh_interval,
+            enabled: true,
+        };
+
+        let config_worker = worker::config::ConfigWorker::new(
+            config.platform.base_url.clone(),
+            config.platform.api_key.clone(),
+            refresh_interval,
+            state.skels.clone(),
+            args.security_rules_config.clone(),
+        ).with_agent_mode(config.mode == "agent")
+            .with_nftables(state.nftables_firewall.clone())
+            .with_iptables(state.iptables_firewall.clone());
+
+        if let Err(e) = worker_manager.register_worker(worker_config, config_worker) {
+            log::error!("Failed to register config worker: {}", e);
+        }
+    } else {
+        // In local mode, register config worker that loads from file at startup only
+        let refresh_interval = 10;
+        let worker_config = worker::WorkerConfig {
+            name: "config".to_string(),
+            interval_secs: refresh_interval,
+            enabled: true,
+        };
+
+        let config_worker = worker::config::ConfigWorker::new(
+            String::new(),
+            String::new(),
+            refresh_interval,
+            state.skels.clone(),
+            args.security_rules_config.clone(),
+        ).with_agent_mode(config.mode == "agent")
+            .with_nftables(state.nftables_firewall.clone())
+            .with_iptables(state.iptables_firewall.clone());
+
+        if let Err(e) = worker_manager.register_worker(worker_config, config_worker) {
+            log::error!("Failed to register config worker: {}", e);
+        }
+    }
+
+    // Start the old Pingora proxy system in a separate thread (non-blocking)
+    // Only start if mode is "proxy" (disabled in agent mode)
+    #[cfg(feature = "proxy")]
+    if config.mode == "proxy" {
+        let bpf_stats_config = config.bpf_stats.clone();
+        let logging_config = config.logging.clone();
+        let platform_config = config.platform.clone();
+        let geoip_config = config.geoip.clone();
+        let network_config = config.network.clone();
+        let tcp_fingerprint_config = config.tcp_fingerprint.clone();
+        let pingora_config = config.pingora.clone();
+        std::thread::spawn(move || {
+            http_proxy::start::run_with_config(Some(crate::cli::Config {
+                mode: "proxy".to_string(),
+                worker_threads: None,
+                redis: Default::default(),
+                network: network_config,
+                platform: platform_config,
+                geoip: geoip_config,
+                content_scanning: Default::default(),
+                logging: logging_config,
+                bpf_stats: bpf_stats_config,
+                tcp_fingerprint: tcp_fingerprint_config,
+                daemon: Default::default(),
+                pingora: pingora_config,
+                acme: Default::default(),
+            }));
+        });
+    }
+
+    // Start BPF statistics logging task
+    let bpf_stats_handle = if config.bpf_stats.enabled && !state.skels.is_empty() {
+        let collector = state.bpf_stats_collector.clone();
+        let log_interval = config.bpf_stats.log_interval_secs;
+        let shutdown = shutdown_rx.clone();
+        Some(start_bpf_stats_logging(collector, log_interval, shutdown))
+    } else {
+        None
+    };
+
+    // Start dropped IP events logging task
+    let dropped_ip_events_handle = if config.bpf_stats.enabled &&
+        config.bpf_stats.enable_dropped_ip_events &&
+        !state.skels.is_empty() {
+        let collector = state.bpf_stats_collector.clone();
+        let log_interval = config.bpf_stats.dropped_ip_events_interval_secs;
+        let shutdown = shutdown_rx.clone();
+        Some(start_dropped_ip_events_logging(collector, log_interval, shutdown))
+    } else {
+        None
+    };
+
+    // Start TCP fingerprinting statistics logging task
+    let tcp_fingerprint_stats_handle = if config.tcp_fingerprint.enabled && !state.skels.is_empty() {
+        let collector = state.tcp_fingerprint_collector.clone();
+        let log_interval = config.tcp_fingerprint.log_interval_secs;
+        let shutdown = shutdown_rx.clone();
+        let state_clone = Arc::new(state.clone());
+        Some(start_tcp_fingerprint_stats_logging(collector, log_interval, shutdown, state_clone))
+    } else {
+        None
+    };
+
+    // Start TCP fingerprinting events logging task
+    let tcp_fingerprint_events_handle = if config.tcp_fingerprint.enabled &&
+        config.tcp_fingerprint.enable_fingerprint_events &&
+        !state.skels.is_empty() {
+        let collector = state.tcp_fingerprint_collector.clone();
+        let log_interval = config.tcp_fingerprint.fingerprint_events_interval_secs;
+        let shutdown = shutdown_rx.clone();
+        let state_clone = Arc::new(state.clone());
+        Some(start_tcp_fingerprint_events_logging(collector, log_interval, shutdown, state_clone))
+    } else {
+        None
+    };
+
+    // Log startup summary as JSON
+    let xdp_info: serde_json::Value = if xdp_modes.is_empty() {
+        serde_json::json!(null)
+    } else if xdp_modes.len() == 1 {
+        serde_json::json!(xdp_modes[0].1)
+    } else {
+        serde_json::json!(xdp_modes.iter().map(|(iface, mode)| {
+            serde_json::json!({ "interface": iface, "mode": mode })
+        }).collect::<Vec<_>>())
+    };
+
+    let worker_threads = match config.worker_threads {
+        Some(threads) => serde_json::json!(threads),
+        None if is_agent_mode => serde_json::json!(0),
+        None => serde_json::json!(null),
+    };
+
+    let startup_info = serde_json::json!({
+        "event": "startup_complete",
+        "mode": config.mode,
+        "runtime": if config.worker_threads == Some(0) || (config.worker_threads.is_none() && is_agent_mode) {
+            "single_thread"
+        } else {
+            "multi_thread"
+        },
+        "worker_threads": worker_threads,
+        "interfaces": &iface_names,
+        "xdp_enabled": !config.network.disable_xdp && !state.skels.is_empty(),
+        "xdp_mode": xdp_info,
+        "features": {
+            "bpf_filtering": !state.skels.is_empty(),
+            "bpf_stats": config.bpf_stats.enabled && !state.skels.is_empty(),
+            "tcp_fingerprint": config.tcp_fingerprint.enabled && !state.skels.is_empty(),
+            "waf": waf_enabled,
+            "threat_intel": threat_client_enabled,
+            "captcha_server": captcha_server_enabled,
+            "captcha_client": captcha_client_enabled,
+            "content_scanner": content_scanner_enabled,
+            "redis": redis_initialized,
+            "log_sender": log_sender_enabled,
+            "acme": cfg!(feature = "proxy") && config.acme.enabled && !is_agent_mode,
+            "proxy": cfg!(feature = "proxy") && config.mode != "agent",
+        },
+        "api_configured": has_api_key,
+    });
+    log::info!("{}", startup_info);
+
+    signal::ctrl_c().await?;
+    log::info!("Shutdown signal received, stopping servers...");
+    let _ = shutdown_tx.send(true);
+
+    if let Some(handle) = bpf_stats_handle
+        && let Err(err) = handle.await
+    {
+        log::error!("bpf-stats task join error: {err}");
+    }
+
+    if let Some(handle) = dropped_ip_events_handle
+        && let Err(err) = handle.await
+    {
+        log::error!("dropped-ip-events task join error: {err}");
+    }
+
+    if let Some(handle) = tcp_fingerprint_stats_handle
+        && let Err(err) = handle.await
+    {
+        log::error!("tcp-fingerprint-stats task join error: {err}");
+    }
+
+    if let Some(handle) = tcp_fingerprint_events_handle
+        && let Err(err) = handle.await
+    {
+        log::error!("tcp-fingerprint-events task join error: {err}");
+    }
+
+    // Shutdown all workers
+    worker_manager.shutdown();
+    worker_manager.wait_for_all().await;
+    log::info!("Proceeding with cleanup...");
+
+    // Detach XDP programs from interfaces
+    #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))]
+    if !ifindices.is_empty() {
+        log::info!("Detaching XDP programs from {} interfaces...", ifindices.len());
+        for ifindex in ifindices {
+            if let Err(e) = bpf_detach_from_xdp(ifindex) {
+                log::error!("Failed to detach XDP from interface {}: {}", ifindex, e);
+            }
+        }
+    }
+
+    // Cleanup nftables rules if using nftables backend
+    if let Some(ref nft_fw) = nftables_firewall {
+        log::info!("Cleaning up nftables firewall rules...");
+        if let Ok(fw) = nft_fw.lock() {
+            if let Err(e) = fw.cleanup() {
+                log::error!("Failed to cleanup nftables rules: {}", e);
+            }
+        }
+    }
+
+    // Cleanup iptables rules if using iptables backend
+    if let Some(ref ipt_fw) = iptables_firewall {
+        log::info!("Cleaning up iptables firewall rules...");
+        if let Ok(fw) = ipt_fw.lock() {
+            if let Err(e) = fw.cleanup() {
+                log::error!("Failed to cleanup iptables rules: {}", e);
+            }
+        }
+    }
+
+    log::info!("Shutdown complete, exiting...");
+
+    // Force exit immediately - cleanup is done
+    std::process::exit(0);
+}
+
+fn read_env_non_empty(name: &str) -> Option<String> {
+    std::env::var(name)
+        .ok()
+        .map(|value| value.trim().to_string())
+        .filter(|value| !value.is_empty())
+}
+
+fn build_agent_id(agent_name: &str, workspace_id: &str) -> String {
+    if read_env_non_empty("AGENT_ID").is_some() {
+        warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id.");
+    }
+
+    if workspace_id.trim().is_empty() {
+        warn!(
+            "WORKSPACE_ID not set; agent_id derived only from agent_name '{}'. Set WORKSPACE_ID \
+(or ARXIGNIS_WORKSPACE_ID) to avoid collisions across organizations.",
+            agent_name
+        );
+    }
+
+    let input = format!("{}:{}", workspace_id.trim(), agent_name.trim());
+    let digest = Sha256::digest(input.as_bytes());
+    format!("{:x}", digest)
+}
+/// Start a background task that logs BPF statistics periodically
+fn start_bpf_stats_logging(
+    collector: BpfStatsCollector,
+    log_interval_secs: u64,
+    mut shutdown: tokio::sync::watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs));
+
+        loop {
+            tokio::select! {
+                _ = shutdown.changed() => {
+                    if *shutdown.borrow() { break; }
+                }
+                _ = interval.tick() => {
+                    if let Err(e) = collector.log_stats() {
+                        log::warn!("Failed to log BPF statistics: {}", e);
+                    }
+                }
+            }
+        }
+
+        log::info!("BPF statistics logging task stopped");
+    })
+}
+
+/// Start a background task that logs dropped IP events periodically
+fn start_dropped_ip_events_logging(
+    collector: BpfStatsCollector,
+    log_interval_secs: u64,
+    mut shutdown: tokio::sync::watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs));
+
+        loop {
+            tokio::select! {
+                _ = shutdown.changed() => {
+                    if *shutdown.borrow() { break; }
+                }
+                _ = interval.tick() => {
+                    if let Err(e) = collector.log_dropped_ip_events() {
+                        log::warn!("Failed to log dropped IP events: {}", e);
+                    }
+                }
+            }
+        }
+
+        log::info!("Dropped IP events logging task stopped");
+    })
+}
+
+/// Start a background task that logs TCP fingerprinting statistics periodically
+fn start_tcp_fingerprint_stats_logging(
+    collector: TcpFingerprintCollector,
+    log_interval_secs: u64,
+    mut shutdown: tokio::sync::watch::Receiver<bool>,
+    _state: Arc<AppState>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs));
+
+        loop {
+            tokio::select! {
+                _ = shutdown.changed() => {
+                    if *shutdown.borrow() { break; }
+                }
+                _ = interval.tick() => {
+                    if let Err(e) = collector.log_stats() {
+                        log::warn!("Failed to log TCP fingerprinting statistics: {}", e);
+                    }
+                }
+            }
+        }
+
+        log::info!("TCP fingerprinting statistics logging task stopped");
+    })
+}
+
+/// Start a background task that logs TCP fingerprinting events periodically
+fn start_tcp_fingerprint_events_logging(
+    collector: TcpFingerprintCollector,
+    log_interval_secs: u64,
+    mut shutdown: tokio::sync::watch::Receiver<bool>,
+    _state: Arc<AppState>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs));
+
+        loop {
+            tokio::select! {
+                _ = shutdown.changed() => {
+                    if *shutdown.borrow() { break; }
+                }
+                _ = interval.tick() => {
+                    if let Err(e) = collector.log_fingerprint_events() {
+                        log::warn!("Failed to log TCP fingerprinting events: {}", e);
+                    }
+                }
+            }
+        }
+
+        log::info!("TCP fingerprinting events logging task stopped");
+    })
+}
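Before leaving main.rs: a quick property check of the build_agent_id derivation above. This is an illustrative test sketch that relies only on build_agent_id as defined in this file; it avoids asserting a literal digest and checks the stable, workspace-scoped, whitespace-trimmed behavior instead.

#[cfg(test)]
mod agent_id_tests {
    use super::build_agent_id;

    #[test]
    fn agent_id_is_stable_and_workspace_scoped() {
        // Same workspace + name always yields the same 64-char hex digest.
        let a = build_agent_id("edge-1", "ws-123");
        assert_eq!(a, build_agent_id("edge-1", "ws-123"));
        assert_eq!(a.len(), 64);
        assert!(a.chars().all(|c| c.is_ascii_hexdigit()));

        // Different workspaces give different ids for the same agent name,
        // which is the collision the WORKSPACE_ID warning is about.
        assert_ne!(a, build_agent_id("edge-1", "ws-456"));

        // Inputs are trimmed before hashing, so surrounding whitespace is ignored.
        assert_eq!(a, build_agent_id(" edge-1 ", " ws-123 "));
    }
}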
diff --git a/src/proxy_protocol.rs b/src/proxy_protocol.rs
index bf556be..d043f85 100644
--- a/src/proxy_protocol.rs
+++ b/src/proxy_protocol.rs
@@ -1,598 +1,598 @@
-use std::net::SocketAddr;
-use std::time::Duration;
-use anyhow::Result;
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader};
-use tokio::time::timeout;
-use proxy_protocol::{ProxyHeader, parse};
-use bytes::{Bytes, Buf};
-
-/// Information extracted from PROXY protocol header
-#[derive(Debug, Clone)]
-pub struct ProxyInfo {
-    pub source_addr: SocketAddr,
-    pub dest_addr: SocketAddr,
-    pub version: ProxyVersion,
-}
-
-#[derive(Debug, Clone)]
-pub enum ProxyVersion {
-    V1,
-    V2,
-}
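For orientation before the parser: a PROXY protocol v1 header is a single human-readable, CRLF-terminated line, while v2 opens with a fixed 12-byte binary signature (the is_v2_start check below matches its first 8 bytes). A small reference sketch:

// v1: "PROXY" <TCP4|TCP6|UNKNOWN> <src-ip> <dst-ip> <src-port> <dst-port>\r\n
const V1_EXAMPLE: &str = "PROXY TCP4 192.168.1.100 192.168.1.200 12345 80\r\n";

// v2: fixed signature, then version/command, family/protocol, a length
// field, and the binary addresses.
const V2_SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, // "\r\n\r\n\0\r\nQUIT\n"
];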
-/// Parse PROXY protocol header from a buffered stream
-/// Returns the proxy info and a buffered reader with unconsumed data preserved
-pub async fn parse_proxy_protocol_buffered<R>(
-    stream: R,
-    timeout_ms: u64,
-) -> Result<(Option<ProxyInfo>, BufReader<ChainedReader<R>>)>
-where
-    R: AsyncRead + Unpin,
-{
-    log::trace!("Starting PROXY protocol parse with {}ms timeout", timeout_ms);
-    let mut reader = BufReader::with_capacity(512, stream);
-    let timeout_duration = Duration::from_millis(timeout_ms);
-
-    let proxy_info = timeout(timeout_duration, async {
-        // Read enough bytes to detect PROXY protocol
-        // v1: up to 108 bytes (including \r\n)
-        // v2: 16 bytes header + variable length
-        let mut peek_buffer = vec![0u8; 232]; // Enough for v2 with reasonable extensions
-
-        let mut total_read = 0;
-        while total_read < peek_buffer.len() {
-            let n = reader.read(&mut peek_buffer[total_read..]).await?;
-            if n == 0 {
-                log::trace!("PROXY protocol parse: EOF reached after reading {} bytes", total_read);
-                break; // EOF
-            }
-            total_read += n;
-            log::trace!("PROXY protocol parse: read {} bytes (total: {})", n, total_read);
-
-            // Try parsing with what we have so far
-            let mut bytes = Bytes::copy_from_slice(&peek_buffer[..total_read]);
-            let bytes_before = bytes.remaining();
-
-            match parse(&mut bytes) {
-                Ok(header) => {
-                    let consumed = bytes_before - bytes.remaining();
-                    log::trace!("PROXY protocol header successfully parsed: consumed {} bytes, total read: {}", consumed, total_read);
-
-                    // We successfully parsed a header
-                    // The consumed bytes are the header, remaining bytes are application data
-                    let info = header_to_proxy_info(header);
-
-                    // Return remaining bytes to the buffer
-                    // We need to create a new reader that has the unconsumed data
-                    if total_read > consumed {
-                        // We read more than the header, need to preserve the extra bytes
-                        let remaining = &peek_buffer[consumed..total_read];
-                        log::trace!("PROXY protocol: preserving {} extra bytes after header", remaining.len());
-                        let new_reader = create_reader_with_prefix(reader, remaining.to_vec());
-                        return Ok((info, new_reader));
-                    }
-
-                    // No extra bytes, but wrap in empty ChainedReader for type consistency
-                    log::trace!("PROXY protocol: no extra bytes to preserve");
-                    let new_reader = create_reader_with_prefix(reader, Vec::new());
-                    return Ok((info, new_reader));
-                }
-                Err(e) => {
-                    log::trace!("PROXY protocol parse attempt failed: {}", e);
-
-                    // Check if this looks like a PROXY protocol header at all
-                    if total_read >= 8 {
-                        let prefix = &peek_buffer[..8];
-                        // v1 starts with "PROXY "
-                        // v2 starts with specific signature
-                        let is_v1_start = prefix.starts_with(b"PROXY ");
-                        let is_v2_start = prefix.starts_with(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51]);
-
-                        if !is_v1_start && !is_v2_start {
-                            // Definitely not PROXY protocol, return all data to buffer
-                            log::trace!("No PROXY protocol detected (first 8 bytes don't match signature), treating as plain connection. Preview: {:?}",
-                                String::from_utf8_lossy(&prefix[..prefix.len().min(8)]));
-                            let new_reader = create_reader_with_prefix(reader, peek_buffer[..total_read].to_vec());
-                            return Ok((None, new_reader));
-                        }
-
-                        log::trace!("PROXY protocol signature detected but incomplete, continuing to read...");
-                    }
-
-                    // Might be PROXY protocol but incomplete, keep reading
-                    continue;
-                }
-            }
-        }
-
-        // Reached EOF or buffer limit without parsing successfully
-        // Return all data to the buffer
-        if total_read > 0 {
-            log::trace!("PROXY protocol parse completed without header (read {} bytes), preserving all data", total_read);
-            let new_reader = create_reader_with_prefix(reader, peek_buffer[..total_read].to_vec());
-            Ok((None, new_reader))
-        } else {
-            log::trace!("PROXY protocol parse: no data read");
-            let new_reader = create_reader_with_prefix(reader, Vec::new());
-            Ok((None, new_reader))
-        }
-    }).await;
-
-    match proxy_info {
-        Ok(Ok((info, reader))) => {
-            if info.is_some() {
-                log::trace!("PROXY protocol parsing successful");
-            } else {
-                log::trace!("PROXY protocol parsing completed: no header found, treating as plain connection");
-            }
-            Ok((info, reader))
-        },
-        Ok(Err(e)) => {
-            log::warn!("PROXY protocol parsing error: {}", e);
-            Err(e)
-        },
-        Err(_) => {
-            log::warn!("PROXY protocol parsing timeout after {}ms", timeout_ms);
-            Err(anyhow::anyhow!("PROXY protocol parsing timeout"))
-        },
-    }
-}
-
-/// Convert ProxyHeader to ProxyInfo
-fn header_to_proxy_info(header: ProxyHeader) -> Option<ProxyInfo> {
-    match header {
-        ProxyHeader::Version1 { addresses } => {
-            match addresses {
-                proxy_protocol::version1::ProxyAddresses::Ipv4 { source, destination } => {
-                    Some(ProxyInfo {
-                        source_addr: SocketAddr::V4(source),
-                        dest_addr: SocketAddr::V4(destination),
-                        version: ProxyVersion::V1,
-                    })
-                }
-                proxy_protocol::version1::ProxyAddresses::Ipv6 { source, destination } => {
-                    Some(ProxyInfo {
-                        source_addr: SocketAddr::V6(source),
-                        dest_addr: SocketAddr::V6(destination),
-                        version: ProxyVersion::V1,
-                    })
-                }
-                proxy_protocol::version1::ProxyAddresses::Unknown => None,
-            }
-        }
-        ProxyHeader::Version2 { addresses, .. } => {
-            match addresses {
-                proxy_protocol::version2::ProxyAddresses::Ipv4 { source, destination } => {
-                    Some(ProxyInfo {
-                        source_addr: SocketAddr::V4(source),
-                        dest_addr: SocketAddr::V4(destination),
-                        version: ProxyVersion::V2,
-                    })
-                }
-                proxy_protocol::version2::ProxyAddresses::Ipv6 { source, destination } => {
-                    Some(ProxyInfo {
-                        source_addr: SocketAddr::V6(source),
-                        dest_addr: SocketAddr::V6(destination),
-                        version: ProxyVersion::V2,
-                    })
-                }
-                proxy_protocol::version2::ProxyAddresses::Unspec => None,
-                proxy_protocol::version2::ProxyAddresses::Unix { .. } => None,
-            }
-        }
-        _ => None,
-    }
-}
-/// Create a new BufReader with prefix data
-fn create_reader_with_prefix<R>(
-    inner: BufReader<R>,
-    prefix: Vec<u8>,
-) -> BufReader<ChainedReader<R>> {
-    let inner_stream = inner.into_inner();
-    let chained = ChainedReader {
-        prefix: if prefix.is_empty() { None } else { Some(prefix) },
-        prefix_pos: 0,
-        inner: inner_stream,
-    };
-    BufReader::new(chained)
-}
-
-/// A reader that first reads from prefix buffer, then from inner stream
-pub struct ChainedReader<R> {
-    prefix: Option<Vec<u8>>,
-    prefix_pos: usize,
-    inner: R,
-}
-
-impl<R: AsyncRead + Unpin> AsyncRead for ChainedReader<R> {
-    fn poll_read(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        let this = &mut *self;
-
-        // First drain the prefix
-        if let Some(ref prefix) = this.prefix {
-            if this.prefix_pos < prefix.len() {
-                let remaining = &prefix[this.prefix_pos..];
-                let to_copy = remaining.len().min(buf.remaining());
-                buf.put_slice(&remaining[..to_copy]);
-                this.prefix_pos += to_copy;
-
-                if this.prefix_pos >= prefix.len() {
-                    this.prefix = None;
-                }
-
-                return std::task::Poll::Ready(Ok(()));
-            }
-        }
-
-        // Then read from inner
-        std::pin::Pin::new(&mut this.inner).poll_read(cx, buf)
-    }
-}
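The prefix-then-inner contract is easiest to see in isolation: bytes handed back by the header sniffing are replayed before anything further is read from the connection. An illustrative test built only on the helpers above, with tokio::io::duplex standing in for a real TcpStream:

#[cfg(test)]
mod chained_reader_tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn prefix_is_replayed_before_inner_data() {
        // duplex gives an in-memory AsyncRead/AsyncWrite pair.
        let (mut tx, rx) = tokio::io::duplex(64);
        tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            tx.write_all(b" world").await.unwrap();
            // tx dropped here -> EOF on rx
        });

        // Pretend "hello" was over-read while sniffing a header.
        let mut reader = create_reader_with_prefix(BufReader::new(rx), b"hello".to_vec());

        let mut out = String::new();
        reader.read_to_string(&mut out).await.unwrap();
        assert_eq!(out, "hello world");
    }
}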
} => inner.peer_addr(), - Self::Buffered { inner, .. } => inner.get_ref().peer_addr(), - Self::ChainedBuffered { inner, .. } => { - inner.get_ref().inner.peer_addr() - } - } - } - - /// Get the local address from the underlying TCP stream - pub fn local_addr(&self) -> std::io::Result { - match self { - Self::Plain { inner, .. } => inner.local_addr(), - Self::Buffered { inner, .. } => inner.get_ref().local_addr(), - Self::ChainedBuffered { inner, .. } => { - inner.get_ref().inner.local_addr() - } - } - } - - /// Shutdown the write half of the TCP stream - pub async fn shutdown(&mut self) -> std::io::Result<()> { - use tokio::io::AsyncWriteExt; - match self { - Self::Plain { inner, .. } => inner.shutdown().await, - Self::Buffered { inner, .. } => inner.get_mut().shutdown().await, - Self::ChainedBuffered { inner, .. } => { - inner.get_mut().inner.shutdown().await - } - } - } - - /// Write all bytes to the stream - pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - use tokio::io::AsyncWriteExt; - match self { - Self::Plain { inner, .. } => inner.write_all(buf).await, - Self::Buffered { inner, .. } => inner.get_mut().write_all(buf).await, - Self::ChainedBuffered { inner, .. } => { - inner.get_mut().inner.write_all(buf).await - } - } - } -} - -impl AsyncRead for ProxyProtocolStream -where - T: AsyncRead + Unpin, -{ - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } - Self::ChainedBuffered { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } - } - } -} - -impl AsyncWrite for ProxyProtocolStream -where - T: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_write(cx, buf) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_write(cx, buf) - } - Self::ChainedBuffered { inner, .. } => { - std::pin::Pin::new(&mut inner.get_mut().inner).poll_write(cx, buf) - } - } - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_flush(cx) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_flush(cx) - } - Self::ChainedBuffered { inner, .. } => { - std::pin::Pin::new(&mut inner.get_mut().inner).poll_flush(cx) - } - } - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_shutdown(cx) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_shutdown(cx) - } - Self::ChainedBuffered { inner, .. 
} => { - std::pin::Pin::new(&mut inner.get_mut().inner).poll_shutdown(cx) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Cursor; - use proxy_protocol::encode; - - #[tokio::test] - async fn test_parse_proxy_v1_ipv4() { - // Create PROXY protocol v1 IPv4 header - let header = "PROXY TCP4 192.168.1.100 192.168.1.200 12345 80\r\n"; - let stream = Cursor::new(header.as_bytes()); - - let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap(); - - assert!(result.is_some()); - let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::().unwrap()); - assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::().unwrap()); - assert_eq!(info.dest_addr.port(), 80); - matches!(info.version, ProxyVersion::V1); - } - - #[tokio::test] - async fn test_parse_proxy_v1_ipv6() { - // Create PROXY protocol v1 IPv6 header - let header = "PROXY TCP6 2001:db8::1 2001:db8::2 12345 80\r\n"; - let stream = Cursor::new(header.as_bytes()); - - let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap(); - - assert!(result.is_some()); - let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::().unwrap()); - assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "2001:db8::2".parse::().unwrap()); - assert_eq!(info.dest_addr.port(), 80); - matches!(info.version, ProxyVersion::V1); - } - - #[tokio::test] - async fn test_parse_proxy_v2_ipv4() { - // Create PROXY protocol v2 IPv4 header using the crate's builder - let header = ProxyHeader::Version2 { - command: proxy_protocol::version2::ProxyCommand::Proxy, - transport_protocol: proxy_protocol::version2::ProxyTransportProtocol::Stream, - addresses: proxy_protocol::version2::ProxyAddresses::Ipv4 { - source: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 100), 12345), - destination: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 200), 80), - }, - extensions: vec![], - }; - - let data = encode(header).unwrap(); - let stream = Cursor::new(data); - - let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap(); - - assert!(result.is_some()); - let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::().unwrap()); - assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::().unwrap()); - assert_eq!(info.dest_addr.port(), 80); - matches!(info.version, ProxyVersion::V2); - } - - #[tokio::test] - async fn test_parse_proxy_v2_ipv6() { - // Create PROXY protocol v2 IPv6 header using the crate's builder - let header = ProxyHeader::Version2 { - command: proxy_protocol::version2::ProxyCommand::Proxy, - transport_protocol: proxy_protocol::version2::ProxyTransportProtocol::Stream, - addresses: proxy_protocol::version2::ProxyAddresses::Ipv6 { - source: std::net::SocketAddrV6::new( - std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1), - 12345, 0, 0 - ), - destination: std::net::SocketAddrV6::new( - std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 2), - 80, 0, 0 - ), - }, - extensions: vec![], - }; - - let data = encode(header).unwrap(); - let stream = Cursor::new(data); - - let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap(); - - assert!(result.is_some()); - let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::().unwrap()); - assert_eq!(info.source_addr.port(), 12345); - 
+use std::net::SocketAddr;
+use std::time::Duration;
+use anyhow::Result;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader};
+use tokio::time::timeout;
+use proxy_protocol::{ProxyHeader, parse};
+use bytes::{Bytes, Buf};
+
+/// Information extracted from PROXY protocol header
+#[derive(Debug, Clone)]
+pub struct ProxyInfo {
+    pub source_addr: SocketAddr,
+    pub dest_addr: SocketAddr,
+    pub version: ProxyVersion,
+}
+
+#[derive(Debug, Clone)]
+pub enum ProxyVersion {
+    V1,
+    V2,
+}
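A quick reference sketch, separate from the patch itself (the constant and function names here are invented for illustration): the two wire signatures the parser below keys on when deciding whether a connection speaks PROXY protocol at all.

    // v1 is human-readable ASCII terminated by CRLF; v2 opens with a fixed
    // 12-byte binary signature, of which the first 8 bytes are checked below.
    const V1_PREFIX: &[u8] = b"PROXY ";
    const V2_PREFIX: &[u8] = &[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51];

    fn looks_like_proxy_protocol(first_bytes: &[u8]) -> bool {
        first_bytes.starts_with(V1_PREFIX) || first_bytes.starts_with(V2_PREFIX)
    }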
+
+/// Parse PROXY protocol header from a buffered stream
+/// Returns the proxy info and a buffered reader with unconsumed data preserved
+pub async fn parse_proxy_protocol_buffered<R>(
+    stream: R,
+    timeout_ms: u64,
+) -> Result<(Option<ProxyInfo>, BufReader<ChainedReader<R>>)>
+where
+    R: AsyncRead + Unpin,
+{
+    log::trace!("Starting PROXY protocol parse with {}ms timeout", timeout_ms);
+    let mut reader = BufReader::with_capacity(512, stream);
+    let timeout_duration = Duration::from_millis(timeout_ms);
+
+    let proxy_info = timeout(timeout_duration, async {
+        // Read enough bytes to detect PROXY protocol
+        // v1: up to 108 bytes (including \r\n)
+        // v2: 16 bytes header + variable length
+        let mut peek_buffer = vec![0u8; 232]; // Enough for v2 with reasonable extensions
+
+        let mut total_read = 0;
+        while total_read < peek_buffer.len() {
+            let n = reader.read(&mut peek_buffer[total_read..]).await?;
+            if n == 0 {
+                log::trace!("PROXY protocol parse: EOF reached after reading {} bytes", total_read);
+                break; // EOF
+            }
+            total_read += n;
+            log::trace!("PROXY protocol parse: read {} bytes (total: {})", n, total_read);
+
+            // Try parsing with what we have so far
+            let mut bytes = Bytes::copy_from_slice(&peek_buffer[..total_read]);
+            let bytes_before = bytes.remaining();
+
+            match parse(&mut bytes) {
+                Ok(header) => {
+                    let consumed = bytes_before - bytes.remaining();
+                    log::trace!("PROXY protocol header successfully parsed: consumed {} bytes, total read: {}", consumed, total_read);
+
+                    // We successfully parsed a header
+                    // The consumed bytes are the header, remaining bytes are application data
+                    let info = header_to_proxy_info(header);
+
+                    // Return remaining bytes to the buffer
+                    // We need to create a new reader that has the unconsumed data
+                    if total_read > consumed {
+                        // We read more than the header, need to preserve the extra bytes
+                        let remaining = &peek_buffer[consumed..total_read];
+                        log::trace!("PROXY protocol: preserving {} extra bytes after header", remaining.len());
+                        let new_reader = create_reader_with_prefix(reader, remaining.to_vec());
+                        return Ok((info, new_reader));
+                    }
+
+                    // No extra bytes, but wrap in empty ChainedReader for type consistency
+                    log::trace!("PROXY protocol: no extra bytes to preserve");
+                    let new_reader = create_reader_with_prefix(reader, Vec::new());
+                    return Ok((info, new_reader));
+                }
+                Err(e) => {
+                    log::trace!("PROXY protocol parse attempt failed: {}", e);
+
+                    // Check if this looks like a PROXY protocol header at all
+                    if total_read >= 8 {
+                        let prefix = &peek_buffer[..8];
+                        // v1 starts with "PROXY "
+                        // v2 starts with specific signature
+                        let is_v1_start = prefix.starts_with(b"PROXY ");
+                        let is_v2_start = prefix.starts_with(&[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51]);
+
+                        if !is_v1_start && !is_v2_start {
+                            // Definitely not PROXY protocol, return all data to buffer
+                            log::trace!("No PROXY protocol detected (first 8 bytes don't match signature), treating as plain connection. Preview: {:?}",
+                                String::from_utf8_lossy(&prefix[..prefix.len().min(8)]));
+                            let new_reader = create_reader_with_prefix(reader, peek_buffer[..total_read].to_vec());
+                            return Ok((None, new_reader));
+                        }
+
+                        log::trace!("PROXY protocol signature detected but incomplete, continuing to read...");
+                    }
+
+                    // Might be PROXY protocol but incomplete, keep reading
+                    continue;
+                }
+            }
+        }
+
+        // Reached EOF or buffer limit without parsing successfully
+        // Return all data to the buffer
+        if total_read > 0 {
+            log::trace!("PROXY protocol parse completed without header (read {} bytes), preserving all data", total_read);
+            let new_reader = create_reader_with_prefix(reader, peek_buffer[..total_read].to_vec());
+            Ok((None, new_reader))
+        } else {
+            log::trace!("PROXY protocol parse: no data read");
+            let new_reader = create_reader_with_prefix(reader, Vec::new());
+            Ok((None, new_reader))
+        }
+    }).await;
+
+    match proxy_info {
+        Ok(Ok((info, reader))) => {
+            if info.is_some() {
+                log::trace!("PROXY protocol parsing successful");
+            } else {
+                log::trace!("PROXY protocol parsing completed: no header found, treating as plain connection");
+            }
+            Ok((info, reader))
+        },
+        Ok(Err(e)) => {
+            log::warn!("PROXY protocol parsing error: {}", e);
+            Err(e)
+        },
+        Err(_) => {
+            log::warn!("PROXY protocol parsing timeout after {}ms", timeout_ms);
+            Err(anyhow::anyhow!("PROXY protocol parsing timeout"))
+        },
+    }
+}
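A minimal usage sketch, not part of the patch (the `handle` function and buffer size are illustrative): the parser consumes only the header, and any application bytes read past it are replayed by the returned reader.

    async fn handle(conn: tokio::net::TcpStream) -> anyhow::Result<()> {
        // 1000 ms budget for the whole header, enforced by tokio::time::timeout above
        let (info, mut reader) = parse_proxy_protocol_buffered(conn, 1000).await?;
        if let Some(info) = &info {
            log::info!("real client address: {}", info.source_addr);
        }
        // Application data starts here, including any bytes over-read during parsing
        let mut first = vec![0u8; 4096];
        let n = tokio::io::AsyncReadExt::read(&mut reader, &mut first).await?;
        log::debug!("read {} bytes of application data", n);
        Ok(())
    }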
+
+/// Convert ProxyHeader to ProxyInfo
+fn header_to_proxy_info(header: ProxyHeader) -> Option<ProxyInfo> {
+    match header {
+        ProxyHeader::Version1 { addresses } => {
+            match addresses {
+                proxy_protocol::version1::ProxyAddresses::Ipv4 { source, destination } => {
+                    Some(ProxyInfo {
+                        source_addr: SocketAddr::V4(source),
+                        dest_addr: SocketAddr::V4(destination),
+                        version: ProxyVersion::V1,
+                    })
+                }
+                proxy_protocol::version1::ProxyAddresses::Ipv6 { source, destination } => {
+                    Some(ProxyInfo {
+                        source_addr: SocketAddr::V6(source),
+                        dest_addr: SocketAddr::V6(destination),
+                        version: ProxyVersion::V1,
+                    })
+                }
+                proxy_protocol::version1::ProxyAddresses::Unknown => None,
+            }
+        }
+        ProxyHeader::Version2 { addresses, .. } => {
+            match addresses {
+                proxy_protocol::version2::ProxyAddresses::Ipv4 { source, destination } => {
+                    Some(ProxyInfo {
+                        source_addr: SocketAddr::V4(source),
+                        dest_addr: SocketAddr::V4(destination),
+                        version: ProxyVersion::V2,
+                    })
+                }
+                proxy_protocol::version2::ProxyAddresses::Ipv6 { source, destination } => {
+                    Some(ProxyInfo {
+                        source_addr: SocketAddr::V6(source),
+                        dest_addr: SocketAddr::V6(destination),
+                        version: ProxyVersion::V2,
+                    })
+                }
+                proxy_protocol::version2::ProxyAddresses::Unspec => None,
+                proxy_protocol::version2::ProxyAddresses::Unix { .. } => None,
+            }
+        }
+        _ => None,
+    }
+}
+
+/// Create a new BufReader with prefix data
+fn create_reader_with_prefix<R: AsyncRead + Unpin>(
+    inner: BufReader<R>,
+    prefix: Vec<u8>,
+) -> BufReader<ChainedReader<R>> {
+    let inner_stream = inner.into_inner();
+    let chained = ChainedReader {
+        prefix: if prefix.is_empty() { None } else { Some(prefix) },
+        prefix_pos: 0,
+        inner: inner_stream,
+    };
+    BufReader::new(chained)
+}
+
+/// A reader that first reads from prefix buffer, then from inner stream
+pub struct ChainedReader<R> {
+    prefix: Option<Vec<u8>>,
+    prefix_pos: usize,
+    inner: R,
+}
+
+impl<R: AsyncRead + Unpin> AsyncRead for ChainedReader<R> {
+    fn poll_read(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        let this = &mut *self;
+
+        // First drain the prefix
+        if let Some(ref prefix) = this.prefix {
+            if this.prefix_pos < prefix.len() {
+                let remaining = &prefix[this.prefix_pos..];
+                let to_copy = remaining.len().min(buf.remaining());
+                buf.put_slice(&remaining[..to_copy]);
+                this.prefix_pos += to_copy;
+
+                if this.prefix_pos >= prefix.len() {
+                    this.prefix = None;
+                }
+
+                return std::task::Poll::Ready(Ok(()));
+            }
+        }
+
+        // Then read from inner
+        std::pin::Pin::new(&mut this.inner).poll_read(cx, buf)
+    }
+}
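A small same-module test sketch (not in the patch) of the replay behaviour ChainedReader provides: the preserved prefix is yielded before any bytes from the inner stream.

    #[tokio::test]
    async fn chained_reader_replays_prefix_first() {
        use tokio::io::AsyncReadExt;
        let mut reader = ChainedReader {
            prefix: Some(b"hello".to_vec()),
            prefix_pos: 0,
            inner: std::io::Cursor::new(b" world".to_vec()),
        };
        let mut out = String::new();
        reader.read_to_string(&mut out).await.unwrap();
        assert_eq!(out, "hello world");
    }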
+
+/// A wrapper around a TCP stream that handles PROXY protocol parsing
+///
+/// This wrapper transparently handles PROXY protocol v1 and v2 parsing and provides
+/// access to the real client address information while preserving unconsumed data.
+pub enum ProxyProtocolStream<T> {
+    Plain {
+        inner: T,
+        proxy_info: Option<ProxyInfo>,
+    },
+    Buffered {
+        inner: BufReader<T>,
+        proxy_info: Option<ProxyInfo>,
+    },
+    ChainedBuffered {
+        inner: BufReader<ChainedReader<T>>,
+        proxy_info: Option<ProxyInfo>,
+    },
+}
+
+impl<T> ProxyProtocolStream<T>
+where
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    pub async fn new(
+        stream: T,
+        proxy_protocol_enabled: bool,
+        timeout_ms: u64,
+    ) -> Result<Self> {
+        if proxy_protocol_enabled {
+            let (proxy_info, reader) = parse_proxy_protocol_buffered(stream, timeout_ms).await?;
+
+            // Check if reader is BufReader<T> or BufReader<ChainedReader<T>>
+            // This is determined by the parse function
+            // For now, we'll use the ChainedBuffered variant as it's more general
+            Ok(Self::ChainedBuffered {
+                inner: reader,
+                proxy_info,
+            })
+        } else {
+            Ok(Self::Plain {
+                inner: stream,
+                proxy_info: None,
+            })
+        }
+    }
+
+    pub fn proxy_info(&self) -> Option<&ProxyInfo> {
+        match self {
+            Self::Plain { proxy_info, .. } => proxy_info.as_ref(),
+            Self::Buffered { proxy_info, .. } => proxy_info.as_ref(),
+            Self::ChainedBuffered { proxy_info, .. } => proxy_info.as_ref(),
+        }
+    }
+
+    pub fn real_client_addr(&self) -> Option<SocketAddr> {
+        self.proxy_info().map(|info| info.source_addr)
+    }
+
+    /// Returns true if PROXY protocol was detected and parsed successfully
+    pub fn has_proxy_info(&self) -> bool {
+        self.proxy_info().is_some()
+    }
+
+    /// Extract the inner stream, consuming this wrapper
+    /// WARNING: This discards any buffered data that was read during PROXY protocol parsing!
+    /// Only use this if you're certain no data was read beyond the PROXY header.
+    pub fn inner(self) -> T {
+        match self {
+            Self::Plain { inner, .. } => inner,
+            Self::Buffered { inner, .. } => inner.into_inner(),
+            Self::ChainedBuffered { inner, .. } => {
+                let chained = inner.into_inner();
+                chained.inner
+            }
+        }
+    }
+}
+
+// Specific implementation for TcpStream to provide socket methods
+impl ProxyProtocolStream<tokio::net::TcpStream> {
+    /// Get the peer address from the underlying TCP stream
+    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
+        match self {
+            Self::Plain { inner, .. } => inner.peer_addr(),
+            Self::Buffered { inner, .. } => inner.get_ref().peer_addr(),
+            Self::ChainedBuffered { inner, .. } => {
+                inner.get_ref().inner.peer_addr()
+            }
+        }
+    }
+
+    /// Get the local address from the underlying TCP stream
+    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
+        match self {
+            Self::Plain { inner, .. } => inner.local_addr(),
+            Self::Buffered { inner, .. } => inner.get_ref().local_addr(),
+            Self::ChainedBuffered { inner, .. } => {
+                inner.get_ref().inner.local_addr()
+            }
+        }
+    }
+
+    /// Shutdown the write half of the TCP stream
+    pub async fn shutdown(&mut self) -> std::io::Result<()> {
+        use tokio::io::AsyncWriteExt;
+        match self {
+            Self::Plain { inner, .. } => inner.shutdown().await,
+            Self::Buffered { inner, .. } => inner.get_mut().shutdown().await,
+            Self::ChainedBuffered { inner, .. } => {
+                inner.get_mut().inner.shutdown().await
+            }
+        }
+    }
+
+    /// Write all bytes to the stream
+    pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+        use tokio::io::AsyncWriteExt;
+        match self {
+            Self::Plain { inner, .. } => inner.write_all(buf).await,
+            Self::Buffered { inner, .. } => inner.get_mut().write_all(buf).await,
+            Self::ChainedBuffered { inner, .. } => {
+                inner.get_mut().inner.write_all(buf).await
+            }
+        }
+    }
+}
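An accept-loop sketch, not part of the patch (listener setup is assumed elsewhere): prefer the PROXY-supplied client address and fall back to the TCP peer address when no header was present.

    async fn accept_loop(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
        loop {
            let (socket, _) = listener.accept().await?;
            let stream = ProxyProtocolStream::new(socket, true, 1000).await?;
            // real_client_addr() is None for plain connections; peer_addr() still works
            let client = stream.real_client_addr().or_else(|| stream.peer_addr().ok());
            log::info!("accepted connection from {:?}", client);
        }
    }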
+
+impl<T> AsyncRead for ProxyProtocolStream<T>
+where
+    T: AsyncRead + Unpin,
+{
+    fn poll_read(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        match &mut *self {
+            Self::Plain { inner, .. } => {
+                std::pin::Pin::new(inner).poll_read(cx, buf)
+            }
+            Self::Buffered { inner, .. } => {
+                std::pin::Pin::new(inner).poll_read(cx, buf)
+            }
+            Self::ChainedBuffered { inner, .. } => {
+                std::pin::Pin::new(inner).poll_read(cx, buf)
+            }
+        }
+    }
+}
+
+impl<T> AsyncWrite for ProxyProtocolStream<T>
+where
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    fn poll_write(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<std::io::Result<usize>> {
+        match &mut *self {
+            Self::Plain { inner, .. } => {
+                std::pin::Pin::new(inner).poll_write(cx, buf)
+            }
+            Self::Buffered { inner, .. } => {
+                std::pin::Pin::new(inner.get_mut()).poll_write(cx, buf)
+            }
+            Self::ChainedBuffered { inner, .. } => {
+                std::pin::Pin::new(&mut inner.get_mut().inner).poll_write(cx, buf)
+            }
+        }
+    }
+
+    fn poll_flush(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        match &mut *self {
+            Self::Plain { inner, .. } => {
+                std::pin::Pin::new(inner).poll_flush(cx)
+            }
+            Self::Buffered { inner, .. } => {
+                std::pin::Pin::new(inner.get_mut()).poll_flush(cx)
+            }
+            Self::ChainedBuffered { inner, .. } => {
+                std::pin::Pin::new(&mut inner.get_mut().inner).poll_flush(cx)
+            }
+        }
+    }
+
+    fn poll_shutdown(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        match &mut *self {
+            Self::Plain { inner, .. } => {
+                std::pin::Pin::new(inner).poll_shutdown(cx)
+            }
+            Self::Buffered { inner, .. } => {
+                std::pin::Pin::new(inner.get_mut()).poll_shutdown(cx)
+            }
+            Self::ChainedBuffered { inner, .. } => {
+                std::pin::Pin::new(&mut inner.get_mut().inner).poll_shutdown(cx)
+            }
+        }
+    }
+}
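Because the wrapper implements both AsyncRead and AsyncWrite, it drops into tokio's generic I/O plumbing unchanged. A sketch (the function name and upstream address are hypothetical):

    async fn pipe_to_upstream(
        mut client: ProxyProtocolStream<tokio::net::TcpStream>,
    ) -> anyhow::Result<()> {
        let mut upstream = tokio::net::TcpStream::connect("127.0.0.1:8080").await?;
        // Writes bypass the read buffer: poll_write goes straight to the underlying stream
        tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
        Ok(())
    }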
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::io::Cursor;
+    use std::net::IpAddr;
+    use proxy_protocol::encode;
+
+    #[tokio::test]
+    async fn test_parse_proxy_v1_ipv4() {
+        // Create PROXY protocol v1 IPv4 header
+        let header = "PROXY TCP4 192.168.1.100 192.168.1.200 12345 80\r\n";
+        let stream = Cursor::new(header.as_bytes());
+
+        let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+
+        assert!(result.is_some());
+        let info = result.unwrap();
+        assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::<IpAddr>().unwrap());
+        assert_eq!(info.source_addr.port(), 12345);
+        assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::<IpAddr>().unwrap());
+        assert_eq!(info.dest_addr.port(), 80);
+        assert!(matches!(info.version, ProxyVersion::V1));
+    }
+
+    #[tokio::test]
+    async fn test_parse_proxy_v1_ipv6() {
+        // Create PROXY protocol v1 IPv6 header
+        let header = "PROXY TCP6 2001:db8::1 2001:db8::2 12345 80\r\n";
+        let stream = Cursor::new(header.as_bytes());
+
+        let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+
+        assert!(result.is_some());
+        let info = result.unwrap();
+        assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::<IpAddr>().unwrap());
+        assert_eq!(info.source_addr.port(), 12345);
+        assert_eq!(info.dest_addr.ip(), "2001:db8::2".parse::<IpAddr>().unwrap());
+        assert_eq!(info.dest_addr.port(), 80);
+        assert!(matches!(info.version, ProxyVersion::V1));
+    }
+
+    #[tokio::test]
+    async fn test_parse_proxy_v2_ipv4() {
+        // Create PROXY protocol v2 IPv4 header using the crate's builder
+        let header = ProxyHeader::Version2 {
+            command: proxy_protocol::version2::ProxyCommand::Proxy,
+            transport_protocol: proxy_protocol::version2::ProxyTransportProtocol::Stream,
+            addresses: proxy_protocol::version2::ProxyAddresses::Ipv4 {
+                source: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 100), 12345),
+                destination: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 200), 80),
+            },
+            extensions: vec![],
+        };
+
+        let data = encode(header).unwrap();
+        let stream = Cursor::new(data);
+
+        let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+
+        assert!(result.is_some());
+        let info = result.unwrap();
+        assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::<IpAddr>().unwrap());
+        assert_eq!(info.source_addr.port(), 12345);
+        assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::<IpAddr>().unwrap());
+        assert_eq!(info.dest_addr.port(), 80);
+        assert!(matches!(info.version, ProxyVersion::V2));
+    }
+
+    #[tokio::test]
+    async fn test_parse_proxy_v2_ipv6() {
+        // Create PROXY protocol v2 IPv6 header using the crate's builder
+        let header = ProxyHeader::Version2 {
+            command: proxy_protocol::version2::ProxyCommand::Proxy,
+            transport_protocol: proxy_protocol::version2::ProxyTransportProtocol::Stream,
+            addresses: proxy_protocol::version2::ProxyAddresses::Ipv6 {
+                source: std::net::SocketAddrV6::new(
+                    std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1),
+                    12345, 0, 0
+                ),
+                destination: std::net::SocketAddrV6::new(
+                    std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 2),
+                    80, 0, 0
+                ),
+            },
+            extensions: vec![],
+        };
+
+        let data = encode(header).unwrap();
+        let stream = Cursor::new(data);
+
+        let (result, _reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+
+        assert!(result.is_some());
+        let info = result.unwrap();
+        assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::<IpAddr>().unwrap());
+        assert_eq!(info.source_addr.port(), 12345);
+        assert_eq!(info.dest_addr.ip(), "2001:db8::2".parse::<IpAddr>().unwrap());
+        assert_eq!(info.dest_addr.port(), 80);
+        assert!(matches!(info.version, ProxyVersion::V2));
+    }
+
+    #[tokio::test]
+    async fn test_no_proxy_header() {
+        // Test with data that doesn't contain a PROXY protocol header
+        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
+        let stream = Cursor::new(&data[..]);
+
+        let (result, mut reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+        assert!(result.is_none());
+
+        // Verify data is preserved in the reader
+        let mut buf = Vec::new();
+        reader.read_to_end(&mut buf).await.unwrap();
+        assert_eq!(&buf[..], &data[..]);
+    }
+
+    #[tokio::test]
+    async fn test_proxy_protocol_stream_wrapper() {
+        // Test the wrapper functionality
+        let header = "PROXY TCP4 192.168.1.100 192.168.1.200 12345 80\r\n";
+        let http_request = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
+        let data = format!("{}{}", header, http_request);
+        let stream = Cursor::new(data.into_bytes());
+
+        let mut wrapper = ProxyProtocolStream::new(stream, true, 1000).await.unwrap();
+
+        assert!(wrapper.has_proxy_info());
+        assert_eq!(wrapper.real_client_addr().unwrap().ip(), "192.168.1.100".parse::<IpAddr>().unwrap());
+        assert_eq!(wrapper.real_client_addr().unwrap().port(), 12345);
+
+        // Verify HTTP request is still readable
+        let mut buf = Vec::new();
+        wrapper.read_to_end(&mut buf).await.unwrap();
+        assert_eq!(&buf[..], http_request.as_bytes());
+    }
+
+    #[tokio::test]
+    async fn test_proxy_with_partial_http_data() {
+        // Test with PROXY header + partial HTTP data read in one go
+        let header = "PROXY TCP4 10.0.0.1 10.0.0.2 55555 443\r\n";
+        let http_data = "POST /api HTTP/1.1\r\nContent-Length: 100\r\n\r\n";
+        let full_data = format!("{}{}", header, http_data);
+        let stream = Cursor::new(full_data.into_bytes());
+
+        let (info, mut reader) = parse_proxy_protocol_buffered(stream, 1000).await.unwrap();
+
+        assert!(info.is_some());
+        let proxy_info = info.unwrap();
+        assert_eq!(proxy_info.source_addr.ip(), "10.0.0.1".parse::<IpAddr>().unwrap());
+
+        // Verify HTTP data is preserved
+        let mut buf = Vec::new();
+        reader.read_to_end(&mut buf).await.unwrap();
+        assert_eq!(&buf[..], http_data.as_bytes());
+    }
+}
diff --git a/src/threat/mod.rs b/src/threat/mod.rs
index 11d0d2f..3300b01 100644
--- a/src/threat/mod.rs
+++ b/src/threat/mod.rs
@@ -1,1205 +1,1205 @@
-use std::{net::IpAddr, path::PathBuf, sync::Arc, time::Duration};
-
-use anyhow::{anyhow, Context, Result};
-use chrono::{DateTime, Utc};
-use maxminddb::{geoip2, MaxMindDbError, Reader};
-use memmap2::MmapOptions;
-use std::fs::File;
-#[cfg(feature = "proxy")]
-use pingora_memory_cache::MemoryCache;
-#[cfg(not(feature = "proxy"))]
-use dashmap::DashMap;
-#[cfg(not(feature = "proxy"))]
-use std::time::Instant;
-use serde::{Deserialize, Deserializer, Serialize};
-use tokio::sync::{OnceCell, RwLock};
-
-/// Custom deserializer for optional datetime fields that can be empty strings or missing
-fn deserialize_optional_datetime<'de, D>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    match Option::<String>::deserialize(deserializer)?
{ - Some(s) => { - if s.is_empty() { - Ok(None) - } else { - DateTime::parse_from_rfc3339(&s) - .map(|dt| Some(dt.with_timezone(&Utc))) - .map_err(serde::de::Error::custom) - } - } - None => Ok(None), - } -} - -/// Threat intelligence response (REST shape) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ThreatResponse { - pub schema_version: String, - pub tenant_id: String, - pub ip: String, - pub intel: ThreatIntel, - pub context: ThreatContext, - pub advice: String, - pub ttl_s: u64, - pub generated_at: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ThreatIntel { - pub score: u32, - pub confidence: f64, - pub score_version: String, - pub categories: Vec, - pub tags: Vec, - #[serde(deserialize_with = "deserialize_optional_datetime")] - pub first_seen: Option>, - #[serde(deserialize_with = "deserialize_optional_datetime")] - pub last_seen: Option>, - pub source_count: u32, - pub reason_code: String, - pub reason_summary: String, - pub rule_id: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ThreatContext { - pub asn: u32, - pub org: String, - pub ip_version: u8, - pub geo: GeoInfo, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeoInfo { - pub country: String, - pub iso_code: String, - #[serde(rename = "asniso_code")] - pub asn_iso_code: String, -} - -/// WAF fields extracted from threat data -#[derive(Debug, Clone)] -pub struct WafFields { - pub ip_src_country: String, - pub ip_src_asn: u32, - pub ip_src_asn_org: String, - pub ip_src_asn_country: String, - pub threat_score: u32, - pub threat_advice: String, -} - -impl From<&ThreatResponse> for WafFields { - fn from(threat: &ThreatResponse) -> Self { - Self { - ip_src_country: threat.context.geo.iso_code.clone(), - ip_src_asn: threat.context.asn, - ip_src_asn_org: threat.context.org.clone(), - ip_src_asn_country: threat.context.geo.asn_iso_code.clone(), - threat_score: threat.intel.score, - threat_advice: threat.advice.clone(), - } - } -} - -#[derive(Clone, Copy)] -pub struct CacheStatus { - hit: bool, -} - -impl CacheStatus { - pub fn is_hit(&self) -> bool { - self.hit - } -} - -#[cfg(feature = "proxy")] -struct ThreatCache { - inner: MemoryCache, -} - -#[cfg(feature = "proxy")] -impl ThreatCache { - fn new(max_entries: usize) -> Self { - Self { - inner: MemoryCache::new(max_entries), - } - } - - fn get(&self, key: &str) -> (Option, CacheStatus) { - let (value, status) = self.inner.get(key); - (value, CacheStatus { hit: status.is_hit() }) - } - - fn put(&self, key: &str, value: ThreatResponse, ttl: Option) { - self.inner.put(key, value, ttl); - } -} - -#[cfg(not(feature = "proxy"))] -struct ThreatCacheEntry { - value: ThreatResponse, - expires_at: Instant, -} - -#[cfg(not(feature = "proxy"))] -struct ThreatCache { - entries: DashMap, - max_entries: usize, -} - -#[cfg(not(feature = "proxy"))] -impl ThreatCache { - fn new(max_entries: usize) -> Self { - Self { - entries: DashMap::new(), - max_entries, - } - } - - fn get(&self, key: &str) -> (Option, CacheStatus) { - if let Some(entry) = self.entries.get(key) { - if Instant::now() < entry.expires_at { - return (Some(entry.value.clone()), CacheStatus { hit: true }); - } - } - self.entries.remove(key); - (None, CacheStatus { hit: false }) - } - - fn put(&self, key: &str, value: ThreatResponse, ttl: Option) { - let ttl = match ttl { - Some(t) if !t.is_zero() => t, - _ => return, - }; - - if self.entries.len() >= self.max_entries { - if let Some(oldest_key) = self.entries.iter().next().map(|entry| 
entry.key().clone()) { - self.entries.remove(&oldest_key); - } - } - - let expires_at = Instant::now() - .checked_add(ttl) - .unwrap_or_else(Instant::now); - - self.entries.insert( - key.to_string(), - ThreatCacheEntry { - value, - expires_at, - }, - ); - } -} - -#[cfg(feature = "proxy")] -pub struct VersionCache { - inner: MemoryCache, -} - -#[cfg(feature = "proxy")] -impl VersionCache { - fn new(max_entries: usize) -> Self { - Self { - inner: MemoryCache::new(max_entries), - } - } - - pub fn get(&self, key: &str) -> (Option, CacheStatus) { - let (value, status) = self.inner.get(key); - (value, CacheStatus { hit: status.is_hit() }) - } - - pub fn put(&self, key: &str, value: String, ttl: Option) { - self.inner.put(key, value, ttl); - } -} - -#[cfg(not(feature = "proxy"))] -struct VersionCacheEntry { - value: String, - expires_at: Option, -} - -#[cfg(not(feature = "proxy"))] -pub struct VersionCache { - entries: DashMap, - max_entries: usize, -} - -#[cfg(not(feature = "proxy"))] -impl VersionCache { - fn new(max_entries: usize) -> Self { - Self { - entries: DashMap::new(), - max_entries, - } - } - - pub fn get(&self, key: &str) -> (Option, CacheStatus) { - if let Some(entry) = self.entries.get(key) { - if let Some(expires_at) = entry.expires_at { - if Instant::now() >= expires_at { - self.entries.remove(key); - return (None, CacheStatus { hit: false }); - } - } - return (Some(entry.value.clone()), CacheStatus { hit: true }); - } - (None, CacheStatus { hit: false }) - } - - pub fn put(&self, key: &str, value: String, ttl: Option) { - let expires_at = match ttl { - Some(t) if !t.is_zero() => Instant::now().checked_add(t), - Some(_) => return, - None => None, - }; - - if self.entries.len() >= self.max_entries { - if let Some(oldest_key) = self.entries.iter().next().map(|entry| entry.key().clone()) { - self.entries.remove(&oldest_key); - } - } - - self.entries.insert( - key.to_string(), - VersionCacheEntry { value, expires_at }, - ); - } -} - -/// Threat intel client: Threat MMDB first, then GeoIP MMDB fallback, with in-memory cache -pub struct ThreatClient { - threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, - threat_reader: RwLock>>>, - geoip_country_reader: RwLock>>>, - geoip_asn_reader: RwLock>>>, - geoip_city_reader: RwLock>>>, - cache: Arc, -} - -/// Default cache size for threat response cache (10,000 entries) -pub const DEFAULT_THREAT_CACHE_SIZE: usize = 10_000; - -/// Smaller cache size suitable for agent mode or low-memory environments -pub const SMALL_THREAT_CACHE_SIZE: usize = 1_000; - -impl ThreatClient { - /// Create a new threat client with default cache size (10,000 entries) - pub fn new( - threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, - ) -> Self { - Self::with_cache_size( - threat_mmdb_path, - geoip_country_path, - geoip_asn_path, - geoip_city_path, - DEFAULT_THREAT_CACHE_SIZE, - ) - } - - /// Create a new threat client with configurable cache size - /// Use SMALL_THREAT_CACHE_SIZE for agent mode or low-memory environments - pub fn with_cache_size( - threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, - cache_size: usize, - ) -> Self { - Self { - threat_mmdb_path, - geoip_country_path, - geoip_asn_path, - geoip_city_path, - threat_reader: RwLock::new(None), - geoip_country_reader: RwLock::new(None), - geoip_asn_reader: RwLock::new(None), - geoip_city_reader: RwLock::new(None), - cache: 
Arc::new(ThreatCache::new(cache_size)), - } - } - - pub async fn refresh_threat(&self) -> Result<()> { - self.refresh_threat_reader().await.map(|_| ()) - } - - pub async fn refresh_geoip(&self) -> Result<()> { - // Refresh all geoip databases - if let Err(e) = self.refresh_geoip_country_reader().await { - log::warn!("Failed to refresh GeoIP Country database: {}", e); - } - if let Err(e) = self.refresh_geoip_asn_reader().await { - log::warn!("Failed to refresh GeoIP ASN database: {}", e); - } - if let Err(e) = self.refresh_geoip_city_reader().await { - log::warn!("Failed to refresh GeoIP City database: {}", e); - } - Ok(()) - } - - /// Get threat intelligence for an IP address with caching - /// Priority: Cache → Threat MMDB → GeoIP MMDB (REST API disabled) - pub async fn get_threat_intel(&self, ip: &str) -> Result> { - // L1 cache - let (cached, status) = self.cache.get(ip); - if let Some(data) = cached { - if status.is_hit() { - log::debug!("Threat data for {} found in cache", ip); - return Ok(Some(data)); - } - } - - let ip_addr: IpAddr = match ip.parse() { - Ok(v) => v, - Err(e) => { - log::warn!("Invalid IP {}: {}", ip, e); - return Ok(None); - } - }; - - // Check Threat MMDB first (if configured) - log::info!("🔍 [Threat] Checking Threat MMDB for {}", ip); - if let Some(threat_data) = self.lookup_threat_mmdb(ip, ip_addr).await? { - log::info!("🔍 [Threat] Found threat data in Threat MMDB for {}: score={}", ip, threat_data.intel.score); - self.set_cache(ip, &threat_data).await; - return Ok(Some(threat_data)); - } - - // REST API disabled - skip directly to GeoIP fallback - log::info!("🔍 [Threat] No threat data found for {} in Threat MMDB, using GeoIP fallback", ip); - - // GeoIP fallback - let (geo, asn, org) = self.lookup_geo(ip_addr).await?; - let response = build_no_data_response(ip, ip_addr, geo, asn, org); - self.set_cache(ip, &response).await; - Ok(Some(response)) - } - - /// Get WAF fields for an IP address - pub async fn get_waf_fields(&self, ip: &str) -> Result> { - if let Some(threat_data) = self.get_threat_intel(ip).await? { - Ok(Some(WafFields::from(&threat_data))) - } else { - Ok(None) - } - } - - - /// Open the MMDB from the configured local path using memory-mapped file access - async fn refresh_threat_reader(&self) -> Result>> { - let mut path = self - .threat_mmdb_path - .clone() - .ok_or_else(|| anyhow!("Threat MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("threat.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open Threat MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map Threat MMDB from {:?}", path))? 
- }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse Threat MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.threat_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("Threat MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn refresh_geoip_country_reader(&self) -> Result>> { - let mut path = self - .geoip_country_path - .clone() - .ok_or_else(|| anyhow!("GeoIP Country MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-Country.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP Country MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP Country MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP Country MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_country_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP Country MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn refresh_geoip_asn_reader(&self) -> Result>> { - let mut path = self - .geoip_asn_path - .clone() - .ok_or_else(|| anyhow!("GeoIP ASN MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-ASN.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP ASN MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP ASN MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP ASN MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_asn_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP ASN MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn refresh_geoip_city_reader(&self) -> Result>> { - let mut path = self - .geoip_city_path - .clone() - .ok_or_else(|| anyhow!("GeoIP City MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-City.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP City MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP City MMDB from {:?}", path))? 
- }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP City MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_city_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP City MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn ensure_geoip_country_reader(&self) -> Result>> { - { - let guard = self.geoip_country_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_country_reader().await - } - - async fn ensure_geoip_asn_reader(&self) -> Result>> { - { - let guard = self.geoip_asn_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_asn_reader().await - } - - async fn ensure_geoip_city_reader(&self) -> Result>> { - { - let guard = self.geoip_city_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_city_reader().await - } - - /// Look up threat intelligence from the Threat MMDB - async fn lookup_threat_mmdb(&self, ip: &str, ip_addr: IpAddr) -> Result> { - log::info!("🔍 [Threat MMDB] Starting lookup for {}", ip); - - // Check if threat reader is available - let reader_opt = { - let guard = self.threat_reader.read().await; - guard.clone() - }; - - let reader = match reader_opt { - Some(r) => { - log::info!("🔍 [Threat MMDB] Reader available, performing lookup"); - r - } - None => { - log::warn!("🔍 [Threat MMDB] Reader not loaded, skipping threat lookup for {}", ip); - return Ok(None); - } - }; - - // Perform blocking MMDB lookup in a separate thread - let result = tokio::task::spawn_blocking({ - let reader = reader.clone(); - let ip_addr_clone = ip_addr; - move || -> Result { - let lookup_result = reader.lookup(ip_addr_clone)?; - if !lookup_result.has_data() { - return Err(maxminddb::MaxMindDbError::invalid_input("IP address not found in database")); - } - lookup_result.decode()? 
- .ok_or_else(|| maxminddb::MaxMindDbError::invalid_input("Failed to decode threat data")) - } - }) - .await; - - match result { - Ok(Ok(threat_data)) => { - log::info!("🔍 [Threat MMDB] Found data for {}: {:?}", ip, threat_data); - Ok(Some(threat_data)) - } - Ok(Err(e)) => { - log::debug!("🔍 [Threat MMDB] IP {} not found or error: {}", ip, e); - Ok(None) - } - Err(e) => { - log::warn!("🔍 [Threat MMDB] Task error for {}: {}", ip, e); - Ok(None) - } - } - } - - async fn lookup_geo(&self, ip: IpAddr) -> Result<(GeoInfo, u32, String)> { - log::debug!("🔍 [GeoIP] Looking up IP: {}", ip); - - // ASN lookup (use ASN database if available, otherwise try country database) - let (asn_num, asn_org) = if let Ok(reader) = self.ensure_geoip_asn_reader().await { - tokio::task::spawn_blocking({ - let reader = reader.clone(); - let ip_clone = ip; - move || -> Result<(u32, String), MaxMindDbError> { - match reader.lookup(ip_clone) { - Ok(lookup_result) => { - if !lookup_result.has_data() { - log::warn!("🔍 [GeoIP] ASN Lookup: No data for {}", ip_clone); - return Ok((0, String::new())); - } - match lookup_result.decode::() { - Ok(Some(res)) => { - let asn = res.autonomous_system_number.unwrap_or(0); - let org = res.autonomous_system_organization.unwrap_or("").to_string(); - log::info!("🔍 [GeoIP] ASN Lookup Success for {}: ASN={}, Org='{}'", ip_clone, asn, org); - Ok((asn, org)) - } - Ok(None) => { - log::warn!("🔍 [GeoIP] ASN Lookup: No ASN data for {}", ip_clone); - Ok((0, String::new())) - } - Err(e) => { - log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e); - Ok((0, String::new())) - } - } - } - Err(e) => { - log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e); - Ok((0, String::new())) - } - } - } - }) - .await?? - } else { - (0, String::new()) - }; - - // Country lookup (use country database if available, fallback to city database) - let reader2 = if let Ok(reader) = self.ensure_geoip_country_reader().await { - reader - } else if let Ok(reader) = self.ensure_geoip_city_reader().await { - reader - } else { - return Err(anyhow!("No GeoIP database available for country lookup")); - }; - let (iso, country_name): (String, String) = tokio::task::spawn_blocking({ - let reader = reader2.clone(); - let ip_clone = ip; - move || -> Result<(String, String), MaxMindDbError> { - match reader.lookup(ip_clone) { - Ok(lookup_result) => { - if !lookup_result.has_data() { - log::warn!("🔍 [GeoIP] Country Lookup: No data for {}", ip_clone); - return Ok((String::new(), String::new())); - } - match lookup_result.decode::() { - Ok(Some(country)) => { - // Debug: Log the raw country data structure - log::info!("🔍 [GeoIP] RAW Country Data for {}: {:?}", ip_clone, country); - - // Try multiple fields - let iso = country - .country - .iso_code - .unwrap_or("") - .to_string(); - - // Also try registered_country if country is empty - let iso = if iso.is_empty() { - country - .registered_country - .iso_code - .unwrap_or("") - .to_string() - } else { - iso - }; - - let country_name = country - .country - .names - .english - .map(|s| s.to_string()) - .unwrap_or_default(); - - if iso.is_empty() { - log::warn!("🔍 [GeoIP] Country Lookup for {} returned EMPTY iso_code (IP found but no country data in ANY field)", ip_clone); - } else { - log::info!("🔍 [GeoIP] Country Lookup Success for {}: Country='{}' ({})", ip_clone, iso, country_name); - } - Ok((iso, country_name)) - } - Ok(None) => { - log::warn!("🔍 [GeoIP] Country Lookup: No country data for {}", ip_clone); - Ok((String::new(), String::new())) - } - Err(e) => { 
- log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); - Ok((String::new(), String::new())) - } - } - } - Err(e) => { - log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); - Ok((String::new(), String::new())) - } - } - } - }) - .await??; - - Ok(( - GeoInfo { - country: country_name, - iso_code: iso.clone(), - asn_iso_code: iso, - }, - asn_num, - asn_org, - )) - } - - /// Set data in the threat cache with TTL from record - async fn set_cache(&self, ip: &str, data: &ThreatResponse) { - let ttl = Duration::from_secs(data.ttl_s); - self.cache.put(ip, data.clone(), Some(ttl)); - } -} - -/// Global threat client instance -static THREAT_CLIENT: OnceCell> = OnceCell::const_new(); - -/// Initialize the global threat client with default cache size -pub async fn init_threat_client( - threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, -) -> Result<()> { - init_threat_client_with_cache_size( - threat_mmdb_path, - geoip_country_path, - geoip_asn_path, - geoip_city_path, - DEFAULT_THREAT_CACHE_SIZE, - ).await -} - -/// Initialize the global threat client with configurable cache size -/// Use SMALL_THREAT_CACHE_SIZE for agent mode or low-memory environments -pub async fn init_threat_client_with_cache_size( - threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, - cache_size: usize, -) -> Result<()> { - let client = Arc::new(ThreatClient::with_cache_size( - threat_mmdb_path, - geoip_country_path, - geoip_asn_path, - geoip_city_path, - cache_size, - )); - - log::info!("Initializing threat client with cache size: {} entries", cache_size); - - // Best-effort initial load from local MMDB paths. - // If the files are not present yet, the workers can download them later. 
- if let Err(e) = client.refresh_threat_reader().await { - log::warn!("Initial Threat MMDB load failed: {}", e); - } - - if let Err(e) = client.refresh_geoip_country_reader().await { - log::warn!("Initial GeoIP Country MMDB load failed: {}", e); - } - if let Err(e) = client.refresh_geoip_asn_reader().await { - log::warn!("Initial GeoIP ASN MMDB load failed: {}", e); - } - if let Err(e) = client.refresh_geoip_city_reader().await { - log::warn!("Initial GeoIP City MMDB load failed: {}", e); - } - - THREAT_CLIENT - .set(client) - .map_err(|_| anyhow::anyhow!("Failed to initialize threat client"))?; - - Ok(()) -} - -/// Trigger an immediate Threat MMDB refresh (used by worker) -pub async fn refresh_threat_mmdb() -> Result<()> { - let client = THREAT_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - - client.refresh_threat().await -} - -/// Trigger an immediate GeoIP MMDB refresh for all databases (used by worker) -pub async fn refresh_geoip_mmdb() -> Result<()> { - let client = THREAT_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - - client.refresh_geoip().await -} - -/// Trigger an immediate GeoIP Country MMDB refresh (used by worker) -pub async fn refresh_geoip_country_mmdb() -> Result<()> { - let client = THREAT_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - - client.refresh_geoip_country_reader().await.map(|_| ()) -} - -/// Trigger an immediate GeoIP ASN MMDB refresh (used by worker) -pub async fn refresh_geoip_asn_mmdb() -> Result<()> { - let client = THREAT_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - - client.refresh_geoip_asn_reader().await.map(|_| ()) -} - -/// Trigger an immediate GeoIP City MMDB refresh (used by worker) -pub async fn refresh_geoip_city_mmdb() -> Result<()> { - let client = THREAT_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - - client.refresh_geoip_city_reader().await.map(|_| ()) -} - -/// Get threat intelligence for an IP address -pub async fn get_threat_intel(ip: &str) -> Result> { - let client = match THREAT_CLIENT.get() { - Some(c) => c, - None => { - log::trace!("Threat client not initialized (API key not provided), skipping threat intel lookup for {}", ip); - return Ok(None); - } - }; - - client.get_threat_intel(ip).await -} - -/// Get access to a version cache (separate from threat response cache) -/// Uses the same pingora-memory-cache pattern as the threat response cache -pub fn get_version_cache() -> Result> { - use std::sync::OnceLock; - static VERSION_CACHE: OnceLock> = OnceLock::new(); - Ok(VERSION_CACHE.get_or_init(|| Arc::new(VersionCache::new(100))).clone()) -} - -/// Get WAF fields for an IP address -pub async fn get_waf_fields(ip: &str) -> Result> { - let client = match THREAT_CLIENT.get() { - Some(c) => c, - None => { - log::trace!("Threat client not initialized (API key not provided), skipping WAF fields lookup for {}", ip); - return Ok(None); - } - }; - - client.get_waf_fields(ip).await -} - - -fn build_no_data_response(ip: &str, ip_addr: IpAddr, geo: GeoInfo, asn: u32, org: String) -> ThreatResponse { - ThreatResponse { - schema_version: "1.0".to_string(), - tenant_id: "geoip".to_string(), - ip: ip.to_string(), - intel: ThreatIntel { - score: 0, - confidence: 0.0, - score_version: "geoip".to_string(), - categories: vec![], - tags: vec![], - first_seen: None, - last_seen: None, - source_count: 0, - reason_code: "NO_DATA".to_string(), - reason_summary: 
"No threat data available".to_string(), - rule_id: "none".to_string(), - }, - context: ThreatContext { - asn, - org, - ip_version: if ip_addr.is_ipv4() { 4 } else { 6 }, - geo, - }, - advice: "allow".to_string(), - ttl_s: 300, - generated_at: Utc::now(), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::net::IpAddr; - - #[test] - fn test_build_no_data_response_ipv4() { - let ip = "192.168.1.1"; - let ip_addr: IpAddr = ip.parse().unwrap(); - let geo = GeoInfo { - country: "United States".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }; - let response = build_no_data_response(ip, ip_addr, geo.clone(), 12345, "Test Org".to_string()); - - assert_eq!(response.ip, ip); - assert_eq!(response.context.ip_version, 4); - assert_eq!(response.context.asn, 12345); - assert_eq!(response.context.org, "Test Org"); - assert_eq!(response.context.geo.iso_code, "US"); - assert_eq!(response.intel.score, 0); - assert_eq!(response.advice, "allow"); - assert_eq!(response.intel.reason_code, "NO_DATA"); - } - - #[test] - fn test_build_no_data_response_ipv6() { - let ip = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"; - let ip_addr: IpAddr = ip.parse().unwrap(); - let geo = GeoInfo { - country: "United States".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }; - let response = build_no_data_response(ip, ip_addr, geo.clone(), 67890, "Test Org 2".to_string()); - - assert_eq!(response.ip, ip); - assert_eq!(response.context.ip_version, 6); - assert_eq!(response.context.asn, 67890); - assert_eq!(response.context.org, "Test Org 2"); - } - - #[test] - fn test_waf_fields_from_threat_response() { - let threat = ThreatResponse { - schema_version: "1.0".to_string(), - tenant_id: "test".to_string(), - ip: "192.168.1.1".to_string(), - intel: ThreatIntel { - score: 75, - confidence: 0.85, - score_version: "1.0".to_string(), - categories: vec!["malware".to_string()], - tags: vec!["suspicious".to_string()], - first_seen: None, - last_seen: None, - source_count: 5, - reason_code: "THREAT_DETECTED".to_string(), - reason_summary: "Threat detected".to_string(), - rule_id: "rule1".to_string(), - }, - context: ThreatContext { - asn: 12345, - org: "Test Org".to_string(), - ip_version: 4, - geo: GeoInfo { - country: "United States".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }, - }, - advice: "block".to_string(), - ttl_s: 3600, - generated_at: Utc::now(), - }; - - let waf_fields = WafFields::from(&threat); - assert_eq!(waf_fields.ip_src_country, "US"); - assert_eq!(waf_fields.ip_src_asn, 12345); - assert_eq!(waf_fields.ip_src_asn_org, "Test Org"); - assert_eq!(waf_fields.ip_src_asn_country, "US"); - assert_eq!(waf_fields.threat_score, 75); - assert_eq!(waf_fields.threat_advice, "block"); - } - - #[test] - fn test_threat_client_new() { - let _client = ThreatClient::new(None, None, None, None); - // Just verify it can be created without panicking - assert!(true); - } - - #[tokio::test] - async fn test_threat_client_get_threat_intel_invalid_ip() { - let client = ThreatClient::new(None, None, None, None); - let result = client.get_threat_intel("invalid-ip").await; - assert!(result.is_ok()); - assert!(result.unwrap().is_none()); - } - - #[tokio::test] - async fn test_threat_client_get_waf_fields_no_client() { - let client = ThreatClient::new(None, None, None, None); - let result = client.get_waf_fields("192.168.1.1").await; - // Should return Ok(None) when no databases are configured or IP not found - // The method may return an error if 
IP parsing fails, so we just check it doesn't panic - let _ = result; - } - - #[test] - fn test_geo_info_serialization() { - let geo = GeoInfo { - country: "United States".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }; - let json = serde_json::to_string(&geo).unwrap(); - assert!(json.contains("US")); - assert!(json.contains("United States")); - } - - #[test] - fn test_threat_response_serialization() { - let response = ThreatResponse { - schema_version: "1.0".to_string(), - tenant_id: "test".to_string(), - ip: "192.168.1.1".to_string(), - intel: ThreatIntel { - score: 50, - confidence: 0.75, - score_version: "1.0".to_string(), - categories: vec!["test".to_string()], - tags: vec![], - first_seen: None, - last_seen: None, - source_count: 1, - reason_code: "TEST".to_string(), - reason_summary: "Test".to_string(), - rule_id: "test".to_string(), - }, - context: ThreatContext { - asn: 12345, - org: "Test".to_string(), - ip_version: 4, - geo: GeoInfo { - country: "US".to_string(), - iso_code: "US".to_string(), - asn_iso_code: "US".to_string(), - }, - }, - advice: "allow".to_string(), - ttl_s: 300, - generated_at: Utc::now(), - }; - - let json = serde_json::to_string(&response).unwrap(); - assert!(json.contains("192.168.1.1")); - assert!(json.contains("\"score\":50")); - assert!(json.contains("\"asn\":12345")); - } - - #[test] - fn test_maxminddb_error_type() { - // Verify that MaxMindDbError (not MaxMindDBError) is the correct type - let error = maxminddb::MaxMindDbError::invalid_input("test error"); - match error { - maxminddb::MaxMindDbError::InvalidInput { message } => { - assert_eq!(message, "test error"); - } - _ => panic!("Expected InvalidInput variant"), - } - } - - #[test] - fn test_maxminddb_error_invalid_database() { - let error = maxminddb::MaxMindDbError::invalid_database("corrupted database"); - match error { - maxminddb::MaxMindDbError::InvalidDatabase { message, .. } => { - assert_eq!(message, "corrupted database"); - } - _ => panic!("Expected InvalidDatabase variant"), - } - } - - #[test] - fn test_maxminddb_error_decoding() { - let error = maxminddb::MaxMindDbError::decoding("decoding failed"); - match error { - maxminddb::MaxMindDbError::Decoding { message, .. 
} => {
- assert_eq!(message, "decoding failed");
- }
- _ => panic!("Expected Decoding variant"),
- }
- }
-
- #[test]
- fn test_geoip2_country_structure() {
- // Test that we can construct a geoip2::Country structure
- // This verifies the field access patterns we use in the code
- use maxminddb::geoip2;
-
- // Create a default Country structure
- let country_record = geoip2::Country::default();
-
- // Verify the structure has the expected fields
- // country.country and country.registered_country are direct structs (not Options)
- let _iso_code = country_record.country.iso_code;
- let _registered_iso = country_record.registered_country.iso_code;
- let _english_name = country_record.country.names.english;
-
- // Just verify it compiles and the structure is correct
- assert!(true);
- }
-
- #[test]
- fn test_geoip2_asn_structure() {
- // Test that we can work with geoip2::Asn structure
- // Note: Asn doesn't have Default, but we can verify the field types exist
- // Verify the structure has the expected fields by checking the type
- // autonomous_system_number: Option<u32>
- // autonomous_system_organization: Option<&str>
- // This test just documents the expected structure
- assert!(true);
- }
-
- #[test]
- fn test_names_english_field() {
- // Verify that Names has an 'english' field (not a get() method)
- use maxminddb::geoip2;
-
- let names = geoip2::Names::default();
-
- // Access the english field directly (not via get("en"))
- let _english_name: Option<&str> = names.english;
-
- // Verify other language fields exist
- let _german: Option<&str> = names.german;
- let _french: Option<&str> = names.french;
- let _spanish: Option<&str> = names.spanish;
-
- assert!(true);
- }
-
- #[test]
- fn test_country_iso_code_access() {
- // Verify that country.iso_code is Option<&str>, not a nested structure
- use maxminddb::geoip2;
-
- let country_info = geoip2::country::Country::default();
-
- // iso_code should be Option<&str>
- let _iso: Option<&str> = country_info.iso_code;
-
- // Verify the pattern: country.country.iso_code.unwrap_or("")
- let country_record = geoip2::Country::default();
- let _iso_from_record: Option<&str> = country_record.country.iso_code;
- let _registered_iso: Option<&str> = country_record.registered_country.iso_code;
-
- assert!(true);
- }
-}
+use std::{net::IpAddr, path::PathBuf, sync::Arc, time::Duration};
+
+use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
+use maxminddb::{geoip2, MaxMindDbError, Reader};
+use memmap2::MmapOptions;
+use std::fs::File;
+#[cfg(feature = "proxy")]
+use pingora_memory_cache::MemoryCache;
+#[cfg(not(feature = "proxy"))]
+use dashmap::DashMap;
+#[cfg(not(feature = "proxy"))]
+use std::time::Instant;
+use serde::{Deserialize, Deserializer, Serialize};
+use tokio::sync::{OnceCell, RwLock};
+
+/// Custom deserializer for optional datetime fields that can be empty strings or missing
+fn deserialize_optional_datetime<'de, D>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ match Option::<String>::deserialize(deserializer)? {
+ Some(s) => {
+ if s.is_empty() {
+ Ok(None)
+ } else {
+ DateTime::parse_from_rfc3339(&s)
+ .map(|dt| Some(dt.with_timezone(&Utc)))
+ .map_err(serde::de::Error::custom)
+ }
+ }
+ None => Ok(None),
+ }
+}
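As a quick illustration, the deserializer accepts three input shapes. A minimal test-style sketch (the SeenAt struct and demo function are hypothetical; serde_json is already used by the tests below):

    #[derive(Deserialize)]
    struct SeenAt {
        #[serde(default, deserialize_with = "deserialize_optional_datetime")]
        seen: Option<DateTime<Utc>>,
    }

    fn deserializer_demo() {
        // empty string -> None
        let a: SeenAt = serde_json::from_str(r#"{"seen": ""}"#).unwrap();
        // RFC 3339 timestamp -> Some(..), normalized to UTC
        let b: SeenAt = serde_json::from_str(r#"{"seen": "2026-01-27T11:53:42+01:00"}"#).unwrap();
        // explicit null -> None
        let c: SeenAt = serde_json::from_str(r#"{"seen": null}"#).unwrap();
        assert!(a.seen.is_none() && b.seen.is_some() && c.seen.is_none());
    }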
+/// Threat intelligence response (REST shape)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ThreatResponse {
+ pub schema_version: String,
+ pub tenant_id: String,
+ pub ip: String,
+ pub intel: ThreatIntel,
+ pub context: ThreatContext,
+ pub advice: String,
+ pub ttl_s: u64,
+ pub generated_at: DateTime<Utc>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ThreatIntel {
+ pub score: u32,
+ pub confidence: f64,
+ pub score_version: String,
+ pub categories: Vec<String>,
+ pub tags: Vec<String>,
+ #[serde(deserialize_with = "deserialize_optional_datetime")]
+ pub first_seen: Option<DateTime<Utc>>,
+ #[serde(deserialize_with = "deserialize_optional_datetime")]
+ pub last_seen: Option<DateTime<Utc>>,
+ pub source_count: u32,
+ pub reason_code: String,
+ pub reason_summary: String,
+ pub rule_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ThreatContext {
+ pub asn: u32,
+ pub org: String,
+ pub ip_version: u8,
+ pub geo: GeoInfo,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GeoInfo {
+ pub country: String,
+ pub iso_code: String,
+ #[serde(rename = "asniso_code")]
+ pub asn_iso_code: String,
+}
+
+/// WAF fields extracted from threat data
+#[derive(Debug, Clone)]
+pub struct WafFields {
+ pub ip_src_country: String,
+ pub ip_src_asn: u32,
+ pub ip_src_asn_org: String,
+ pub ip_src_asn_country: String,
+ pub threat_score: u32,
+ pub threat_advice: String,
+}
+
+impl From<&ThreatResponse> for WafFields {
+ fn from(threat: &ThreatResponse) -> Self {
+ Self {
+ ip_src_country: threat.context.geo.iso_code.clone(),
+ ip_src_asn: threat.context.asn,
+ ip_src_asn_org: threat.context.org.clone(),
+ ip_src_asn_country: threat.context.geo.asn_iso_code.clone(),
+ threat_score: threat.intel.score,
+ threat_advice: threat.advice.clone(),
+ }
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct CacheStatus {
+ hit: bool,
+}
+
+impl CacheStatus {
+ pub fn is_hit(&self) -> bool {
+ self.hit
+ }
+}
+
+#[cfg(feature = "proxy")]
+struct ThreatCache {
+ inner: MemoryCache<String, ThreatResponse>,
+}
+
+#[cfg(feature = "proxy")]
+impl ThreatCache {
+ fn new(max_entries: usize) -> Self {
+ Self {
+ inner: MemoryCache::new(max_entries),
+ }
+ }
+
+ fn get(&self, key: &str) -> (Option<ThreatResponse>, CacheStatus) {
+ let (value, status) = self.inner.get(key);
+ (value, CacheStatus { hit: status.is_hit() })
+ }
+
+ fn put(&self, key: &str, value: ThreatResponse, ttl: Option<Duration>) {
+ self.inner.put(key, value, ttl);
+ }
+}
+
+#[cfg(not(feature = "proxy"))]
+struct ThreatCacheEntry {
+ value: ThreatResponse,
+ expires_at: Instant,
+}
+
+#[cfg(not(feature = "proxy"))]
+struct ThreatCache {
+ entries: DashMap<String, ThreatCacheEntry>,
+ max_entries: usize,
+}
+
+#[cfg(not(feature = "proxy"))]
+impl ThreatCache {
+ fn new(max_entries: usize) -> Self {
+ Self {
+ entries: DashMap::new(),
+ max_entries,
+ }
+ }
+
+ fn get(&self, key: &str) -> (Option<ThreatResponse>, CacheStatus) {
+ if let Some(entry) = self.entries.get(key) {
+ if Instant::now() < entry.expires_at {
+ return (Some(entry.value.clone()), CacheStatus { hit: true });
+ }
+ }
+ self.entries.remove(key);
+ (None, CacheStatus { hit: false })
+ }
+
+ fn put(&self, key: &str, value: ThreatResponse, ttl: Option<Duration>) {
+ let ttl = match ttl {
+ Some(t) if !t.is_zero() => t,
+ _ => return,
+ };
+
+ if self.entries.len() >= self.max_entries {
+ // Bind the key first so the iterator's shard guard is dropped
+ // before remove(); removing while the guard is alive would
+ // self-deadlock on the DashMap shard lock.
+ let evict_key = self.entries.iter().next().map(|entry| entry.key().clone());
+ if let Some(oldest_key) = evict_key {
+ self.entries.remove(&oldest_key);
+ }
+ }
+
+ let expires_at = Instant::now()
+ .checked_add(ttl)
+ .unwrap_or_else(Instant::now);
+
+ self.entries.insert(
+ key.to_string(),
+ ThreatCacheEntry {
+ value,
+ expires_at,
+ },
+ );
+ }
+}
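Both cache backends expose the same get/put call shape, so callers stay feature-agnostic. A minimal sketch of the shared contract (the lookup helper is illustrative, not part of the patch):

    fn lookup(cache: &ThreatCache, ip: &str) -> Option<ThreatResponse> {
        // Both the pingora-backed and the DashMap-backed cache return the
        // value plus a CacheStatus; a stale or missing entry is a miss.
        // Note: the fallback cache evicts an arbitrary entry when full,
        // since DashMap iteration order is unspecified (not strict LRU).
        let (value, status) = cache.get(ip);
        if status.is_hit() { value } else { None }
    }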
+#[cfg(feature = "proxy")]
+pub struct VersionCache {
+ inner: MemoryCache<String, String>,
+}
+
+#[cfg(feature = "proxy")]
+impl VersionCache {
+ fn new(max_entries: usize) -> Self {
+ Self {
+ inner: MemoryCache::new(max_entries),
+ }
+ }
+
+ pub fn get(&self, key: &str) -> (Option<String>, CacheStatus) {
+ let (value, status) = self.inner.get(key);
+ (value, CacheStatus { hit: status.is_hit() })
+ }
+
+ pub fn put(&self, key: &str, value: String, ttl: Option<Duration>) {
+ self.inner.put(key, value, ttl);
+ }
+}
+
+#[cfg(not(feature = "proxy"))]
+struct VersionCacheEntry {
+ value: String,
+ expires_at: Option<Instant>,
+}
+
+#[cfg(not(feature = "proxy"))]
+pub struct VersionCache {
+ entries: DashMap<String, VersionCacheEntry>,
+ max_entries: usize,
+}
+
+#[cfg(not(feature = "proxy"))]
+impl VersionCache {
+ fn new(max_entries: usize) -> Self {
+ Self {
+ entries: DashMap::new(),
+ max_entries,
+ }
+ }
+
+ pub fn get(&self, key: &str) -> (Option<String>, CacheStatus) {
+ if let Some(entry) = self.entries.get(key) {
+ if let Some(expires_at) = entry.expires_at {
+ if Instant::now() >= expires_at {
+ // Drop the shard guard before removing; removing while the
+ // guard is alive would self-deadlock on the DashMap shard.
+ drop(entry);
+ self.entries.remove(key);
+ return (None, CacheStatus { hit: false });
+ }
+ }
+ return (Some(entry.value.clone()), CacheStatus { hit: true });
+ }
+ (None, CacheStatus { hit: false })
+ }
+
+ pub fn put(&self, key: &str, value: String, ttl: Option<Duration>) {
+ let expires_at = match ttl {
+ Some(t) if !t.is_zero() => Instant::now().checked_add(t),
+ Some(_) => return,
+ None => None,
+ };
+
+ if self.entries.len() >= self.max_entries {
+ // Same guard-drop pattern as ThreatCache::put above.
+ let evict_key = self.entries.iter().next().map(|entry| entry.key().clone());
+ if let Some(oldest_key) = evict_key {
+ self.entries.remove(&oldest_key);
+ }
+ }
+
+ self.entries.insert(
+ key.to_string(),
+ VersionCacheEntry { value, expires_at },
+ );
+ }
+}
+
+/// Threat intel client: Threat MMDB first, then GeoIP MMDB fallback, with in-memory cache
+pub struct ThreatClient {
+ threat_mmdb_path: Option<PathBuf>,
+ geoip_country_path: Option<PathBuf>,
+ geoip_asn_path: Option<PathBuf>,
+ geoip_city_path: Option<PathBuf>,
+ threat_reader: RwLock<Option<Arc<Reader<memmap2::Mmap>>>>,
+ geoip_country_reader: RwLock<Option<Arc<Reader<memmap2::Mmap>>>>,
+ geoip_asn_reader: RwLock<Option<Arc<Reader<memmap2::Mmap>>>>,
+ geoip_city_reader: RwLock<Option<Arc<Reader<memmap2::Mmap>>>>,
+ cache: Arc<ThreatCache>,
+}
+
+/// Default cache size for threat response cache (10,000 entries)
+pub const DEFAULT_THREAT_CACHE_SIZE: usize = 10_000;
+
+/// Smaller cache size suitable for agent mode or low-memory environments
+pub const SMALL_THREAT_CACHE_SIZE: usize = 1_000;
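The constructors below only record paths; readers are opened lazily by the refresh methods. A hypothetical end-to-end use of the client (paths are placeholder values):

    async fn threat_client_demo() -> Result<()> {
        let client = ThreatClient::with_cache_size(
            Some(PathBuf::from("/var/lib/synapse/threat.mmdb")),
            Some(PathBuf::from("/var/lib/synapse")), // directory form: filename is appended
            Some(PathBuf::from("/var/lib/synapse/GeoLite2-ASN.mmdb")),
            None, // no City database configured
            SMALL_THREAT_CACHE_SIZE,
        );
        // Best-effort initial load; a missing file is only a warning upstream.
        let _ = client.refresh_threat().await;
        if let Some(intel) = client.get_threat_intel("203.0.113.7").await? {
            println!("advice={} score={}", intel.advice, intel.intel.score);
        }
        Ok(())
    }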
+impl ThreatClient {
+ /// Create a new threat client with default cache size (10,000 entries)
+ pub fn new(
+ threat_mmdb_path: Option<PathBuf>,
+ geoip_country_path: Option<PathBuf>,
+ geoip_asn_path: Option<PathBuf>,
+ geoip_city_path: Option<PathBuf>,
+ ) -> Self {
+ Self::with_cache_size(
+ threat_mmdb_path,
+ geoip_country_path,
+ geoip_asn_path,
+ geoip_city_path,
+ DEFAULT_THREAT_CACHE_SIZE,
+ )
+ }
+
+ /// Create a new threat client with configurable cache size
+ /// Use SMALL_THREAT_CACHE_SIZE for agent mode or low-memory environments
+ pub fn with_cache_size(
+ threat_mmdb_path: Option<PathBuf>,
+ geoip_country_path: Option<PathBuf>,
+ geoip_asn_path: Option<PathBuf>,
+ geoip_city_path: Option<PathBuf>,
+ cache_size: usize,
+ ) -> Self {
+ Self {
+ threat_mmdb_path,
+ geoip_country_path,
+ geoip_asn_path,
+ geoip_city_path,
+ threat_reader: RwLock::new(None),
+ geoip_country_reader: RwLock::new(None),
+ geoip_asn_reader: RwLock::new(None),
+ geoip_city_reader: RwLock::new(None),
+ cache: Arc::new(ThreatCache::new(cache_size)),
+ }
+ }
+
+ pub async fn refresh_threat(&self) -> Result<()> {
+ self.refresh_threat_reader().await.map(|_| ())
+ }
+
+ pub async fn refresh_geoip(&self) -> Result<()> {
+ // Refresh all geoip databases
+ if let Err(e) = self.refresh_geoip_country_reader().await {
+ log::warn!("Failed to refresh GeoIP Country database: {}", e);
+ }
+ if let Err(e) = self.refresh_geoip_asn_reader().await {
+ log::warn!("Failed to refresh GeoIP ASN database: {}", e);
+ }
+ if let Err(e) = self.refresh_geoip_city_reader().await {
+ log::warn!("Failed to refresh GeoIP City database: {}", e);
+ }
+ Ok(())
+ }
+
+ /// Get threat intelligence for an IP address with caching
+ /// Priority: Cache → Threat MMDB → GeoIP MMDB (REST API disabled)
+ pub async fn get_threat_intel(&self, ip: &str) -> Result<Option<ThreatResponse>> {
+ // L1 cache
+ let (cached, status) = self.cache.get(ip);
+ if let Some(data) = cached {
+ if status.is_hit() {
+ log::debug!("Threat data for {} found in cache", ip);
+ return Ok(Some(data));
+ }
+ }
+
+ let ip_addr: IpAddr = match ip.parse() {
+ Ok(v) => v,
+ Err(e) => {
+ log::warn!("Invalid IP {}: {}", ip, e);
+ return Ok(None);
+ }
+ };
+
+ // Check Threat MMDB first (if configured)
+ log::info!("🔍 [Threat] Checking Threat MMDB for {}", ip);
+ if let Some(threat_data) = self.lookup_threat_mmdb(ip, ip_addr).await? {
+ log::info!("🔍 [Threat] Found threat data in Threat MMDB for {}: score={}", ip, threat_data.intel.score);
+ self.set_cache(ip, &threat_data).await;
+ return Ok(Some(threat_data));
+ }
+
+ // REST API disabled - skip directly to GeoIP fallback
+ log::info!("🔍 [Threat] No threat data found for {} in Threat MMDB, using GeoIP fallback", ip);
+
+ // GeoIP fallback
+ let (geo, asn, org) = self.lookup_geo(ip_addr).await?;
+ let response = build_no_data_response(ip, ip_addr, geo, asn, org);
+ self.set_cache(ip, &response).await;
+ Ok(Some(response))
+ }
+
+ /// Get WAF fields for an IP address
+ pub async fn get_waf_fields(&self, ip: &str) -> Result<Option<WafFields>> {
+ if let Some(threat_data) = self.get_threat_intel(ip).await? {
+ Ok(Some(WafFields::from(&threat_data)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Open the MMDB from the configured local path using memory-mapped file access
+ async fn refresh_threat_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ let mut path = self
+ .threat_mmdb_path
+ .clone()
+ .ok_or_else(|| anyhow!("Threat MMDB path not configured"))?;
+
+ // If the path doesn't have a file extension, treat it as a directory and append the filename
+ if path.extension().is_none() {
+ path = path.join("threat.mmdb");
+ }
+
+ // Use spawn_blocking since file operations are blocking
+ let reader = tokio::task::spawn_blocking({
+ let path = path.clone();
+ move || -> Result<Reader<memmap2::Mmap>> {
+ let file = File::open(&path)
+ .with_context(|| format!("Failed to open Threat MMDB file {:?}", path))?;
+ let mmap = unsafe {
+ MmapOptions::new()
+ .map(&file)
+ .with_context(|| format!("Failed to memory-map Threat MMDB from {:?}", path))?
+ };
+ Reader::from_source(mmap)
+ .with_context(|| format!("Failed to parse Threat MMDB from {:?}", path))
+ }
+ })
+ .await
+ .context("Failed to spawn blocking task for MMDB open")??;
+
+ let arc = Arc::new(reader);
+
+ let mut guard = self.threat_reader.write().await;
+ *guard = Some(arc.clone());
+
+ log::info!("Threat MMDB opened (memory-mapped) from {:?}", path);
+
+ Ok(arc)
+ }
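The three GeoIP refresh methods that follow repeat this open, mmap, parse sequence with different default filenames. One possible consolidation, sketched under the same API assumptions (open_mmdb_reader is not part of the patch):

    async fn open_mmdb_reader(mut path: PathBuf, default_file: &str) -> Result<Reader<memmap2::Mmap>> {
        // Same directory-vs-file convention as the methods in this impl.
        if path.extension().is_none() {
            path = path.join(default_file);
        }
        tokio::task::spawn_blocking(move || -> Result<Reader<memmap2::Mmap>> {
            let file = File::open(&path)
                .with_context(|| format!("Failed to open MMDB file {:?}", path))?;
            // SAFETY: same contract as the inline mmap calls in this file.
            let mmap = unsafe { MmapOptions::new().map(&file) }
                .with_context(|| format!("Failed to memory-map MMDB from {:?}", path))?;
            Reader::from_source(mmap)
                .with_context(|| format!("Failed to parse MMDB from {:?}", path))
        })
        .await
        .context("Failed to spawn blocking task for MMDB open")?
    }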
+ async fn refresh_geoip_country_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ let mut path = self
+ .geoip_country_path
+ .clone()
+ .ok_or_else(|| anyhow!("GeoIP Country MMDB path not configured"))?;
+
+ // If the path doesn't have a file extension, treat it as a directory and append the filename
+ if path.extension().is_none() {
+ path = path.join("GeoLite2-Country.mmdb");
+ }
+
+ // Use spawn_blocking since file operations are blocking
+ let reader = tokio::task::spawn_blocking({
+ let path = path.clone();
+ move || -> Result<Reader<memmap2::Mmap>> {
+ let file = File::open(&path)
+ .with_context(|| format!("Failed to open GeoIP Country MMDB file {:?}", path))?;
+ let mmap = unsafe {
+ MmapOptions::new()
+ .map(&file)
+ .with_context(|| format!("Failed to memory-map GeoIP Country MMDB from {:?}", path))?
+ };
+ Reader::from_source(mmap)
+ .with_context(|| format!("Failed to parse GeoIP Country MMDB from {:?}", path))
+ }
+ })
+ .await
+ .context("Failed to spawn blocking task for MMDB open")??;
+
+ let arc = Arc::new(reader);
+
+ let mut guard = self.geoip_country_reader.write().await;
+ *guard = Some(arc.clone());
+
+ log::info!("GeoIP Country MMDB opened (memory-mapped) from {:?}", path);
+
+ Ok(arc)
+ }
+
+ async fn refresh_geoip_asn_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ let mut path = self
+ .geoip_asn_path
+ .clone()
+ .ok_or_else(|| anyhow!("GeoIP ASN MMDB path not configured"))?;
+
+ // If the path doesn't have a file extension, treat it as a directory and append the filename
+ if path.extension().is_none() {
+ path = path.join("GeoLite2-ASN.mmdb");
+ }
+
+ // Use spawn_blocking since file operations are blocking
+ let reader = tokio::task::spawn_blocking({
+ let path = path.clone();
+ move || -> Result<Reader<memmap2::Mmap>> {
+ let file = File::open(&path)
+ .with_context(|| format!("Failed to open GeoIP ASN MMDB file {:?}", path))?;
+ let mmap = unsafe {
+ MmapOptions::new()
+ .map(&file)
+ .with_context(|| format!("Failed to memory-map GeoIP ASN MMDB from {:?}", path))?
+ };
+ Reader::from_source(mmap)
+ .with_context(|| format!("Failed to parse GeoIP ASN MMDB from {:?}", path))
+ }
+ })
+ .await
+ .context("Failed to spawn blocking task for MMDB open")??;
+
+ let arc = Arc::new(reader);
+
+ let mut guard = self.geoip_asn_reader.write().await;
+ *guard = Some(arc.clone());
+
+ log::info!("GeoIP ASN MMDB opened (memory-mapped) from {:?}", path);
+
+ Ok(arc)
+ }
+ async fn refresh_geoip_city_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ let mut path = self
+ .geoip_city_path
+ .clone()
+ .ok_or_else(|| anyhow!("GeoIP City MMDB path not configured"))?;
+
+ // If the path doesn't have a file extension, treat it as a directory and append the filename
+ if path.extension().is_none() {
+ path = path.join("GeoLite2-City.mmdb");
+ }
+
+ // Use spawn_blocking since file operations are blocking
+ let reader = tokio::task::spawn_blocking({
+ let path = path.clone();
+ move || -> Result<Reader<memmap2::Mmap>> {
+ let file = File::open(&path)
+ .with_context(|| format!("Failed to open GeoIP City MMDB file {:?}", path))?;
+ let mmap = unsafe {
+ MmapOptions::new()
+ .map(&file)
+ .with_context(|| format!("Failed to memory-map GeoIP City MMDB from {:?}", path))?
+ };
+ Reader::from_source(mmap)
+ .with_context(|| format!("Failed to parse GeoIP City MMDB from {:?}", path))
+ }
+ })
+ .await
+ .context("Failed to spawn blocking task for MMDB open")??;
+
+ let arc = Arc::new(reader);
+
+ let mut guard = self.geoip_city_reader.write().await;
+ *guard = Some(arc.clone());
+
+ log::info!("GeoIP City MMDB opened (memory-mapped) from {:?}", path);
+
+ Ok(arc)
+ }
+
+ async fn ensure_geoip_country_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ {
+ let guard = self.geoip_country_reader.read().await;
+ if let Some(existing) = guard.as_ref() {
+ return Ok(existing.clone());
+ }
+ }
+ self.refresh_geoip_country_reader().await
+ }
+
+ async fn ensure_geoip_asn_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ {
+ let guard = self.geoip_asn_reader.read().await;
+ if let Some(existing) = guard.as_ref() {
+ return Ok(existing.clone());
+ }
+ }
+ self.refresh_geoip_asn_reader().await
+ }
+
+ async fn ensure_geoip_city_reader(&self) -> Result<Arc<Reader<memmap2::Mmap>>> {
+ {
+ let guard = self.geoip_city_reader.read().await;
+ if let Some(existing) = guard.as_ref() {
+ return Ok(existing.clone());
+ }
+ }
+ self.refresh_geoip_city_reader().await
+ }
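The three ensure_* methods above share a read-fast-path then refresh-slow-path shape. A generic consolidation could look like this sketch (ensure_reader is illustrative; like the originals, it lets two concurrent callers refresh twice):

    async fn ensure_reader(
        slot: &RwLock<Option<Arc<Reader<memmap2::Mmap>>>>,
        refresh: impl std::future::Future<Output = Result<Arc<Reader<memmap2::Mmap>>>>,
    ) -> Result<Arc<Reader<memmap2::Mmap>>> {
        // Fast path: reader already loaded, only a read lock is taken.
        {
            let guard = slot.read().await;
            if let Some(existing) = guard.as_ref() {
                return Ok(existing.clone());
            }
        }
        // Slow path: delegate to the caller-supplied refresh future.
        refresh.await
    }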
+ /// Look up threat intelligence from the Threat MMDB
+ async fn lookup_threat_mmdb(&self, ip: &str, ip_addr: IpAddr) -> Result<Option<ThreatResponse>> {
+ log::info!("🔍 [Threat MMDB] Starting lookup for {}", ip);
+
+ // Check if threat reader is available
+ let reader_opt = {
+ let guard = self.threat_reader.read().await;
+ guard.clone()
+ };
+
+ let reader = match reader_opt {
+ Some(r) => {
+ log::info!("🔍 [Threat MMDB] Reader available, performing lookup");
+ r
+ }
+ None => {
+ log::warn!("🔍 [Threat MMDB] Reader not loaded, skipping threat lookup for {}", ip);
+ return Ok(None);
+ }
+ };
+
+ // Perform blocking MMDB lookup in a separate thread
+ let result = tokio::task::spawn_blocking({
+ let reader = reader.clone();
+ let ip_addr_clone = ip_addr;
+ move || -> Result<ThreatResponse, MaxMindDbError> {
+ let lookup_result = reader.lookup(ip_addr_clone)?;
+ if !lookup_result.has_data() {
+ return Err(maxminddb::MaxMindDbError::invalid_input("IP address not found in database"));
+ }
+ lookup_result.decode()?
+ .ok_or_else(|| maxminddb::MaxMindDbError::invalid_input("Failed to decode threat data"))
+ }
+ })
+ .await;
+
+ match result {
+ Ok(Ok(threat_data)) => {
+ log::info!("🔍 [Threat MMDB] Found data for {}: {:?}", ip, threat_data);
+ Ok(Some(threat_data))
+ }
+ Ok(Err(e)) => {
+ log::debug!("🔍 [Threat MMDB] IP {} not found or error: {}", ip, e);
+ Ok(None)
+ }
+ Err(e) => {
+ log::warn!("🔍 [Threat MMDB] Task error for {}: {}", ip, e);
+ Ok(None)
+ }
+ }
+ }
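For reference, the maxminddb lookup/has_data/decode call shape used here and again in lookup_geo below, reduced to a synchronous sketch (lookup_country_iso is illustrative only):

    fn lookup_country_iso(reader: &Reader<memmap2::Mmap>, ip: IpAddr) -> Option<String> {
        let result = reader.lookup(ip).ok()?; // resolve the ip in the search tree
        if !result.has_data() {
            return None; // ip resolved, but no record attached
        }
        // decode() deserializes the record; None means a decode miss
        let country = result.decode::<geoip2::Country>().ok().flatten()?;
        country.country.iso_code.map(|s| s.to_string())
    }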
+ async fn lookup_geo(&self, ip: IpAddr) -> Result<(GeoInfo, u32, String)> {
+ log::debug!("🔍 [GeoIP] Looking up IP: {}", ip);
+
+ // ASN lookup (use ASN database if available, otherwise try country database)
+ let (asn_num, asn_org) = if let Ok(reader) = self.ensure_geoip_asn_reader().await {
+ tokio::task::spawn_blocking({
+ let reader = reader.clone();
+ let ip_clone = ip;
+ move || -> Result<(u32, String), MaxMindDbError> {
+ match reader.lookup(ip_clone) {
+ Ok(lookup_result) => {
+ if !lookup_result.has_data() {
+ log::warn!("🔍 [GeoIP] ASN Lookup: No data for {}", ip_clone);
+ return Ok((0, String::new()));
+ }
+ match lookup_result.decode::<geoip2::Asn>() {
+ Ok(Some(res)) => {
+ let asn = res.autonomous_system_number.unwrap_or(0);
+ let org = res.autonomous_system_organization.unwrap_or("").to_string();
+ log::info!("🔍 [GeoIP] ASN Lookup Success for {}: ASN={}, Org='{}'", ip_clone, asn, org);
+ Ok((asn, org))
+ }
+ Ok(None) => {
+ log::warn!("🔍 [GeoIP] ASN Lookup: No ASN data for {}", ip_clone);
+ Ok((0, String::new()))
+ }
+ Err(e) => {
+ log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e);
+ Ok((0, String::new()))
+ }
+ }
+ }
+ Err(e) => {
+ log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e);
+ Ok((0, String::new()))
+ }
+ }
+ }
+ })
+ .await??
+ } else {
+ (0, String::new())
+ };
+
+ // Country lookup (use country database if available, fallback to city database)
+ let reader2 = if let Ok(reader) = self.ensure_geoip_country_reader().await {
+ reader
+ } else if let Ok(reader) = self.ensure_geoip_city_reader().await {
+ reader
+ } else {
+ return Err(anyhow!("No GeoIP database available for country lookup"));
+ };
+ let (iso, country_name): (String, String) = tokio::task::spawn_blocking({
+ let reader = reader2.clone();
+ let ip_clone = ip;
+ move || -> Result<(String, String), MaxMindDbError> {
+ match reader.lookup(ip_clone) {
+ Ok(lookup_result) => {
+ if !lookup_result.has_data() {
+ log::warn!("🔍 [GeoIP] Country Lookup: No data for {}", ip_clone);
+ return Ok((String::new(), String::new()));
+ }
+ match lookup_result.decode::<geoip2::Country>() {
+ Ok(Some(country)) => {
+ // Debug: Log the raw country data structure
+ log::info!("🔍 [GeoIP] RAW Country Data for {}: {:?}", ip_clone, country);
+
+ // Prefer country.iso_code, fall back to registered_country
+ let iso = country
+ .country
+ .iso_code
+ .unwrap_or("")
+ .to_string();
+
+ // Also try registered_country if country is empty
+ let iso = if iso.is_empty() {
+ country
+ .registered_country
+ .iso_code
+ .unwrap_or("")
+ .to_string()
+ } else {
+ iso
+ };
+
+ let country_name = country
+ .country
+ .names
+ .english
+ .map(|s| s.to_string())
+ .unwrap_or_default();
+
+ if iso.is_empty() {
+ log::warn!("🔍 [GeoIP] Country Lookup for {} returned EMPTY iso_code (IP found but no country data in ANY field)", ip_clone);
+ } else {
+ log::info!("🔍 [GeoIP] Country Lookup Success for {}: Country='{}' ({})", ip_clone, iso, country_name);
+ }
+ Ok((iso, country_name))
+ }
+ Ok(None) => {
+ log::warn!("🔍 [GeoIP] Country Lookup: No country data for {}", ip_clone);
+ Ok((String::new(), String::new()))
+ }
+ Err(e) => {
+ log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: {}", ip_clone, e);
+ Ok((String::new(), String::new()))
+ }
+ }
+ }
+ Err(e) => {
+ log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: {}", ip_clone, e);
+ Ok((String::new(), String::new()))
+ }
+ }
+ }
+ })
+ .await??;
+
+ Ok((
+ GeoInfo {
+ country: country_name,
+ iso_code: iso.clone(),
+ asn_iso_code: iso,
+ },
+ asn_num,
+ asn_org,
+ ))
+ }
+
+ /// Set data in the threat cache with TTL from record
+ async fn set_cache(&self, ip: &str, data: &ThreatResponse) {
+ let ttl = Duration::from_secs(data.ttl_s);
+ self.cache.put(ip, data.clone(), Some(ttl));
+ }
+}
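Typical startup wiring for the global singleton defined below, sketched with placeholder paths (wire_threat_intel is illustrative, not part of the patch):

    async fn wire_threat_intel() -> Result<()> {
        // Agent mode: small cache, only the databases present on disk.
        init_threat_client_with_cache_size(
            Some(PathBuf::from("/var/lib/synapse/threat.mmdb")),
            Some(PathBuf::from("/var/lib/synapse")), // dir form: GeoLite2-Country.mmdb appended
            None,
            None,
            SMALL_THREAT_CACHE_SIZE,
        )
        .await?;
        // Workers can later force a reload after downloading fresh databases.
        refresh_threat_mmdb().await.ok();
        let _ = get_threat_intel("198.51.100.1").await?;
        Ok(())
    }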
+/// Global threat client instance
+static THREAT_CLIENT: OnceCell<Arc<ThreatClient>> = OnceCell::const_new();
+
+/// Initialize the global threat client with default cache size
+pub async fn init_threat_client(
+ threat_mmdb_path: Option<PathBuf>,
+ geoip_country_path: Option<PathBuf>,
+ geoip_asn_path: Option<PathBuf>,
+ geoip_city_path: Option<PathBuf>,
+) -> Result<()> {
+ init_threat_client_with_cache_size(
+ threat_mmdb_path,
+ geoip_country_path,
+ geoip_asn_path,
+ geoip_city_path,
+ DEFAULT_THREAT_CACHE_SIZE,
+ ).await
+}
+
+/// Initialize the global threat client with configurable cache size
+/// Use SMALL_THREAT_CACHE_SIZE for agent mode or low-memory environments
+pub async fn init_threat_client_with_cache_size(
+ threat_mmdb_path: Option<PathBuf>,
+ geoip_country_path: Option<PathBuf>,
+ geoip_asn_path: Option<PathBuf>,
+ geoip_city_path: Option<PathBuf>,
+ cache_size: usize,
+) -> Result<()> {
+ let client = Arc::new(ThreatClient::with_cache_size(
+ threat_mmdb_path,
+ geoip_country_path,
+ geoip_asn_path,
+ geoip_city_path,
+ cache_size,
+ ));
+
+ log::info!("Initializing threat client with cache size: {} entries", cache_size);
+
+ // Best-effort initial load from local MMDB paths.
+ // If the files are not present yet, the workers can download them later.
+ if let Err(e) = client.refresh_threat_reader().await {
+ log::warn!("Initial Threat MMDB load failed: {}", e);
+ }
+
+ if let Err(e) = client.refresh_geoip_country_reader().await {
+ log::warn!("Initial GeoIP Country MMDB load failed: {}", e);
+ }
+ if let Err(e) = client.refresh_geoip_asn_reader().await {
+ log::warn!("Initial GeoIP ASN MMDB load failed: {}", e);
+ }
+ if let Err(e) = client.refresh_geoip_city_reader().await {
+ log::warn!("Initial GeoIP City MMDB load failed: {}", e);
+ }
+
+ THREAT_CLIENT
+ .set(client)
+ .map_err(|_| anyhow::anyhow!("Failed to initialize threat client"))?;
+
+ Ok(())
+}
+
+/// Trigger an immediate Threat MMDB refresh (used by worker)
+pub async fn refresh_threat_mmdb() -> Result<()> {
+ let client = THREAT_CLIENT
+ .get()
+ .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?;
+
+ client.refresh_threat().await
+}
+
+/// Trigger an immediate GeoIP MMDB refresh for all databases (used by worker)
+pub async fn refresh_geoip_mmdb() -> Result<()> {
+ let client = THREAT_CLIENT
+ .get()
+ .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?;
+
+ client.refresh_geoip().await
+}
+
+/// Trigger an immediate GeoIP Country MMDB refresh (used by worker)
+pub async fn refresh_geoip_country_mmdb() -> Result<()> {
+ let client = THREAT_CLIENT
+ .get()
+ .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?;
+
+ client.refresh_geoip_country_reader().await.map(|_| ())
+}
+
+/// Trigger an immediate GeoIP ASN MMDB refresh (used by worker)
+pub async fn refresh_geoip_asn_mmdb() -> Result<()> {
+ let client = THREAT_CLIENT
+ .get()
+ .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?;
+
+ client.refresh_geoip_asn_reader().await.map(|_| ())
+}
+
+/// Trigger an immediate GeoIP City MMDB refresh (used by worker)
+pub async fn refresh_geoip_city_mmdb() -> Result<()> {
+ let client = THREAT_CLIENT
+ .get()
+ .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?;
+
+ client.refresh_geoip_city_reader().await.map(|_| ())
+}
+
+/// Get threat intelligence for an IP address
+pub async fn get_threat_intel(ip: &str) -> Result<Option<ThreatResponse>> {
+ let client = match THREAT_CLIENT.get() {
+ Some(c) => c,
+ None => {
+ log::trace!("Threat client not initialized (API key not provided), skipping threat intel lookup for {}", ip);
+ return Ok(None);
+ }
+ };
+
+ client.get_threat_intel(ip).await
+}
+
+/// Get access to a version cache (separate from threat response cache)
+/// Uses the same pingora-memory-cache pattern as the threat response cache
+pub fn get_version_cache() -> Result<Arc<VersionCache>> {
+ use std::sync::OnceLock;
+ static VERSION_CACHE: OnceLock<Arc<VersionCache>> = OnceLock::new();
+ Ok(VERSION_CACHE.get_or_init(|| Arc::new(VersionCache::new(100))).clone())
+}
+
+/// Get WAF fields for an IP address
+pub async fn get_waf_fields(ip: &str) -> Result<Option<WafFields>> {
+ let client = match THREAT_CLIENT.get() {
+ Some(c) => c,
+ None => {
+ log::trace!("Threat client not initialized (API key not provided), skipping WAF fields lookup for {}", ip);
+ return Ok(None);
+ }
+ };
+
+ client.get_waf_fields(ip).await
+}
+
+fn build_no_data_response(ip: &str, ip_addr: IpAddr, geo: GeoInfo, asn: u32, org: String) -> ThreatResponse {
+ ThreatResponse {
+ schema_version: "1.0".to_string(),
+ tenant_id: "geoip".to_string(),
+ ip: ip.to_string(),
+ intel: ThreatIntel {
+ score: 0,
+ confidence: 0.0,
+ score_version: "geoip".to_string(),
+ categories: vec![],
+ tags: vec![],
+ first_seen: None,
+ last_seen: None,
+ source_count: 0,
+ reason_code: "NO_DATA".to_string(),
+ reason_summary:
"No threat data available".to_string(), + rule_id: "none".to_string(), + }, + context: ThreatContext { + asn, + org, + ip_version: if ip_addr.is_ipv4() { 4 } else { 6 }, + geo, + }, + advice: "allow".to_string(), + ttl_s: 300, + generated_at: Utc::now(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + + #[test] + fn test_build_no_data_response_ipv4() { + let ip = "192.168.1.1"; + let ip_addr: IpAddr = ip.parse().unwrap(); + let geo = GeoInfo { + country: "United States".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }; + let response = build_no_data_response(ip, ip_addr, geo.clone(), 12345, "Test Org".to_string()); + + assert_eq!(response.ip, ip); + assert_eq!(response.context.ip_version, 4); + assert_eq!(response.context.asn, 12345); + assert_eq!(response.context.org, "Test Org"); + assert_eq!(response.context.geo.iso_code, "US"); + assert_eq!(response.intel.score, 0); + assert_eq!(response.advice, "allow"); + assert_eq!(response.intel.reason_code, "NO_DATA"); + } + + #[test] + fn test_build_no_data_response_ipv6() { + let ip = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"; + let ip_addr: IpAddr = ip.parse().unwrap(); + let geo = GeoInfo { + country: "United States".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }; + let response = build_no_data_response(ip, ip_addr, geo.clone(), 67890, "Test Org 2".to_string()); + + assert_eq!(response.ip, ip); + assert_eq!(response.context.ip_version, 6); + assert_eq!(response.context.asn, 67890); + assert_eq!(response.context.org, "Test Org 2"); + } + + #[test] + fn test_waf_fields_from_threat_response() { + let threat = ThreatResponse { + schema_version: "1.0".to_string(), + tenant_id: "test".to_string(), + ip: "192.168.1.1".to_string(), + intel: ThreatIntel { + score: 75, + confidence: 0.85, + score_version: "1.0".to_string(), + categories: vec!["malware".to_string()], + tags: vec!["suspicious".to_string()], + first_seen: None, + last_seen: None, + source_count: 5, + reason_code: "THREAT_DETECTED".to_string(), + reason_summary: "Threat detected".to_string(), + rule_id: "rule1".to_string(), + }, + context: ThreatContext { + asn: 12345, + org: "Test Org".to_string(), + ip_version: 4, + geo: GeoInfo { + country: "United States".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }, + }, + advice: "block".to_string(), + ttl_s: 3600, + generated_at: Utc::now(), + }; + + let waf_fields = WafFields::from(&threat); + assert_eq!(waf_fields.ip_src_country, "US"); + assert_eq!(waf_fields.ip_src_asn, 12345); + assert_eq!(waf_fields.ip_src_asn_org, "Test Org"); + assert_eq!(waf_fields.ip_src_asn_country, "US"); + assert_eq!(waf_fields.threat_score, 75); + assert_eq!(waf_fields.threat_advice, "block"); + } + + #[test] + fn test_threat_client_new() { + let _client = ThreatClient::new(None, None, None, None); + // Just verify it can be created without panicking + assert!(true); + } + + #[tokio::test] + async fn test_threat_client_get_threat_intel_invalid_ip() { + let client = ThreatClient::new(None, None, None, None); + let result = client.get_threat_intel("invalid-ip").await; + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn test_threat_client_get_waf_fields_no_client() { + let client = ThreatClient::new(None, None, None, None); + let result = client.get_waf_fields("192.168.1.1").await; + // Should return Ok(None) when no databases are configured or IP not found + // The method may return an error if 
IP parsing fails, so we just check it doesn't panic + let _ = result; + } + + #[test] + fn test_geo_info_serialization() { + let geo = GeoInfo { + country: "United States".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }; + let json = serde_json::to_string(&geo).unwrap(); + assert!(json.contains("US")); + assert!(json.contains("United States")); + } + + #[test] + fn test_threat_response_serialization() { + let response = ThreatResponse { + schema_version: "1.0".to_string(), + tenant_id: "test".to_string(), + ip: "192.168.1.1".to_string(), + intel: ThreatIntel { + score: 50, + confidence: 0.75, + score_version: "1.0".to_string(), + categories: vec!["test".to_string()], + tags: vec![], + first_seen: None, + last_seen: None, + source_count: 1, + reason_code: "TEST".to_string(), + reason_summary: "Test".to_string(), + rule_id: "test".to_string(), + }, + context: ThreatContext { + asn: 12345, + org: "Test".to_string(), + ip_version: 4, + geo: GeoInfo { + country: "US".to_string(), + iso_code: "US".to_string(), + asn_iso_code: "US".to_string(), + }, + }, + advice: "allow".to_string(), + ttl_s: 300, + generated_at: Utc::now(), + }; + + let json = serde_json::to_string(&response).unwrap(); + assert!(json.contains("192.168.1.1")); + assert!(json.contains("\"score\":50")); + assert!(json.contains("\"asn\":12345")); + } + + #[test] + fn test_maxminddb_error_type() { + // Verify that MaxMindDbError (not MaxMindDBError) is the correct type + let error = maxminddb::MaxMindDbError::invalid_input("test error"); + match error { + maxminddb::MaxMindDbError::InvalidInput { message } => { + assert_eq!(message, "test error"); + } + _ => panic!("Expected InvalidInput variant"), + } + } + + #[test] + fn test_maxminddb_error_invalid_database() { + let error = maxminddb::MaxMindDbError::invalid_database("corrupted database"); + match error { + maxminddb::MaxMindDbError::InvalidDatabase { message, .. } => { + assert_eq!(message, "corrupted database"); + } + _ => panic!("Expected InvalidDatabase variant"), + } + } + + #[test] + fn test_maxminddb_error_decoding() { + let error = maxminddb::MaxMindDbError::decoding("decoding failed"); + match error { + maxminddb::MaxMindDbError::Decoding { message, .. 
} => { + assert_eq!(message, "decoding failed"); + } + _ => panic!("Expected Decoding variant"), + } + } + + #[test] + fn test_geoip2_country_structure() { + // Test that we can construct a geoip2::Country structure + // This verifies the field access patterns we use in the code + use maxminddb::geoip2; + + // Create a default Country structure + let country_record = geoip2::Country::default(); + + // Verify the structure has the expected fields + // country.country and country.registered_country are direct structs (not Options) + let _iso_code = country_record.country.iso_code; + let _registered_iso = country_record.registered_country.iso_code; + let _english_name = country_record.country.names.english; + + // Just verify it compiles and the structure is correct + assert!(true); + } + + #[test] + fn test_geoip2_asn_structure() { + // Test that we can work with geoip2::Asn structure + // Note: Asn doesn't have Default, but we can verify the field types exist + // Verify the structure has the expected fields by checking the type + // autonomous_system_number: Option + // autonomous_system_organization: Option<&str> + // This test just documents the expected structure + assert!(true); + } + + #[test] + fn test_names_english_field() { + // Verify that Names has an 'english' field (not a get() method) + use maxminddb::geoip2; + + let names = geoip2::Names::default(); + + // Access the english field directly (not via get("en")) + let _english_name: Option<&str> = names.english; + + // Verify other language fields exist + let _german: Option<&str> = names.german; + let _french: Option<&str> = names.french; + let _spanish: Option<&str> = names.spanish; + + assert!(true); + } + + #[test] + fn test_country_iso_code_access() { + // Verify that country.iso_code is Option<&str>, not a nested structure + use maxminddb::geoip2; + + let country_info = geoip2::country::Country::default(); + + // iso_code should be Option<&str> + let _iso: Option<&str> = country_info.iso_code; + + // Verify the pattern: country.country.iso_code.unwrap_or("") + let country_record = geoip2::Country::default(); + let _iso_from_record: Option<&str> = country_record.country.iso_code; + let _registered_iso: Option<&str> = country_record.registered_country.iso_code; + + assert!(true); + } +} diff --git a/src/utils.rs b/src/utils.rs index b0d5d95..9b52921 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,31 +1,31 @@ -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod bpf_utils; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "utils/bpf_utils_noop.rs"] -pub mod bpf_utils; -#[cfg(feature = "proxy")] -pub mod discovery; -#[cfg(feature = "proxy")] -mod filewatch; -#[cfg(feature = "proxy")] -pub mod healthcheck; -pub mod http_utils; -#[cfg(feature = "proxy")] -pub mod metrics; -#[cfg(feature = "proxy")] -pub mod parceyaml; -#[cfg(feature = "proxy")] -pub mod state; -pub mod structs; -#[cfg(feature = "proxy")] -pub mod tls; -#[cfg(feature = "proxy")] -pub mod tls_client_hello; -pub mod tls_fingerprint; -#[cfg(feature = "proxy")] -pub mod tools; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod tcp_fingerprint; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "utils/tcp_fingerprint_noop.rs"] -pub mod tcp_fingerprint; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod bpf_utils; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "utils/bpf_utils_noop.rs"] +pub mod bpf_utils; +#[cfg(feature = "proxy")] +pub mod discovery; 
+#[cfg(feature = "proxy")] +mod filewatch; +#[cfg(feature = "proxy")] +pub mod healthcheck; +pub mod http_utils; +#[cfg(feature = "proxy")] +pub mod metrics; +#[cfg(feature = "proxy")] +pub mod parceyaml; +#[cfg(feature = "proxy")] +pub mod state; +pub mod structs; +#[cfg(feature = "proxy")] +pub mod tls; +#[cfg(feature = "proxy")] +pub mod tls_client_hello; +pub mod tls_fingerprint; +#[cfg(feature = "proxy")] +pub mod tools; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod tcp_fingerprint; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "utils/tcp_fingerprint_noop.rs"] +pub mod tcp_fingerprint; diff --git a/src/utils/bpf_utils.rs b/src/utils/bpf_utils.rs index 32b2afd..caa5bf4 100644 --- a/src/utils/bpf_utils.rs +++ b/src/utils/bpf_utils.rs @@ -1,225 +1,225 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::os::fd::AsFd; -use std::fs; - -use crate::bpf::{self, FilterSkel}; -use libbpf_rs::{Xdp, XdpFlags}; -use nix::libc; - -/// XDP attachment mode -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum XdpMode { - Hardware, - Driver, - DriverReplace, - Skb, - SkbReplace, - SkbIpv6Enabled, - SkbUpdateIfNoexist, -} - -impl XdpMode { - pub fn as_str(&self) -> &'static str { - match self { - XdpMode::Hardware => "hardware", - XdpMode::Driver => "driver", - XdpMode::DriverReplace => "driver_replace", - XdpMode::Skb => "skb", - XdpMode::SkbReplace => "skb_replace", - XdpMode::SkbIpv6Enabled => "skb_ipv6_enabled", - XdpMode::SkbUpdateIfNoexist => "skb_update_if_noexist", - } - } -} - -fn is_ipv6_disabled(iface: Option<&str>) -> bool { - // Check if IPv6 is disabled for a specific interface or system-wide - if let Some(iface_name) = iface { - if let Ok(content) = fs::read_to_string(format!("/proc/sys/net/ipv6/conf/{}/disable_ipv6", iface_name)) { - return content.trim() == "1"; - } - } - // Fall back to system-wide check - if let Ok(content) = fs::read_to_string("/proc/sys/net/ipv6/conf/all/disable_ipv6") { - return content.trim() == "1"; - } - false -} - -fn try_enable_ipv6_for_interface(iface: &str) -> Result<(), Box> { - // Try to enable IPv6 only for the specific interface (not system-wide) - // This allows IPv4-only operation elsewhere while enabling XDP on this interface - let disable_path = format!("/proc/sys/net/ipv6/conf/{}/disable_ipv6", iface); - - if is_ipv6_disabled(Some(iface)) { - log::debug!("IPv6 is disabled for interface {}, attempting to enable it for XDP attachment", iface); - std::fs::write(&disable_path, "0")?; - log::info!("Enabled IPv6 for interface {} (required for XDP, IPv4-only elsewhere)", iface); - Ok(()) - } else { - Ok(()) - } -} - -pub fn bpf_attach_to_xdp( - skel: &mut FilterSkel<'_>, - ifindex: i32, - iface_name: Option<&str>, - ip_version: &str, -) -> Result> { - // Try hardware mode first, fall back to driver mode if not supported - let xdp = Xdp::new(skel.progs.arxignis_xdp_filter.as_fd().into()); - - // Try hardware offload mode first - if let Ok(()) = xdp.attach(ifindex, XdpFlags::HW_MODE) { - return Ok(XdpMode::Hardware); - } - - // Fall back to driver mode if hardware mode fails - match xdp.attach(ifindex, XdpFlags::DRV_MODE) { - Ok(()) => { - return Ok(XdpMode::Driver); - } - Err(e) => { - // Check if error is EEXIST (error 17) - XDP program already attached - let error_msg = e.to_string(); - if error_msg.contains("17") || error_msg.contains("File exists") { - log::debug!("Driver mode failed: XDP program already attached, trying to replace with REPLACE flag"); - // Try to replace existing XDP program - match 
xdp.attach(ifindex, XdpFlags::DRV_MODE | XdpFlags::REPLACE) { - Ok(()) => { - return Ok(XdpMode::DriverReplace); - } - Err(e2) => { - log::debug!("Replace in driver mode failed: {}, trying generic SKB mode", e2); - } - } - } else { - log::debug!("Driver mode failed, trying generic SKB mode: {}", e); - } - } - } - - // Try SKB mode (should work on all interfaces, including IPv4-only) - match xdp.attach(ifindex, XdpFlags::SKB_MODE) { - Ok(()) => { - return Ok(XdpMode::Skb); - } - Err(e) => { - // Check if error is EEXIST (error 17) first - let error_msg = e.to_string(); - if error_msg.contains("17") || error_msg.contains("File exists") { - log::debug!("SKB mode failed: XDP program already attached, trying to replace"); - // Try to replace existing XDP program in SKB mode - match xdp.attach(ifindex, XdpFlags::SKB_MODE | XdpFlags::REPLACE) { - Ok(()) => { - return Ok(XdpMode::SkbReplace); - } - Err(e2) => { - log::debug!("Replace in SKB mode failed: {}, continuing with other fallbacks", e2); - } - } - } - // If SKB mode fails with EAFNOSUPPORT (error 97), it's likely due to IPv6 being disabled - if error_msg.contains("97") || error_msg.contains("Address family not supported") { - log::debug!("SKB mode failed with EAFNOSUPPORT, IPv6 might be disabled"); - - // Note: XDP requires IPv6 to be enabled at the kernel level for attachment, - // even when processing only IPv4 packets. This is a kernel limitation. - // For IPv4-only mode, we can enable IPv6 just for this interface (not system-wide) - // which allows XDP to attach while still operating in IPv4-only mode. - if ip_version == "ipv4" { - log::debug!("IPv4-only mode: Attempting to enable IPv6 on interface for XDP attachment (kernel requirement)"); - } - - // Try to enable IPv6 only for this specific interface (not system-wide) - // This allows IPv4-only operation elsewhere while enabling XDP on this interface - if let Some(iface) = iface_name { - if try_enable_ipv6_for_interface(iface).is_ok() { - log::debug!("Retrying XDP attachment after enabling IPv6 for interface {}", iface); - - // Retry SKB mode after enabling IPv6 for the interface - match xdp.attach(ifindex, XdpFlags::SKB_MODE) { - Ok(()) => { - return Ok(XdpMode::SkbIpv6Enabled); - } - Err(e2) => { - log::debug!("SKB mode still failed after enabling IPv6 for interface: {}", e2); - } - } - } else { - log::debug!("Failed to enable IPv6 for interface {} or no permission", iface); - } - } else { - log::debug!("Interface name not provided, cannot enable IPv6 per-interface"); - } - - // Try with UPDATE_IF_NOEXIST flag as last resort - match xdp.attach(ifindex, XdpFlags::SKB_MODE | XdpFlags::UPDATE_IF_NOEXIST) { - Ok(()) => { - return Ok(XdpMode::SkbUpdateIfNoexist); - } - Err(e2) => { - log::debug!("SKB mode with UPDATE_IF_NOEXIST also failed: {}", e2); - } - } - } - - Err(Box::new(e)) - } - } -} - -pub fn ipv4_to_u32_be(ip: Ipv4Addr) -> u32 { - u32::from_be_bytes(ip.octets()) -} - -pub fn convert_ip_into_bpf_map_key_bytes(ip: Ipv4Addr, prefixlen: u32) -> Box<[u8]> { - let ip_u32: u32 = ip.into(); - let ip_be = ip_u32.to_be(); - - let my_ip_key: bpf::types::lpm_key = bpf::types::lpm_key { - prefixlen, - addr: ip_be, - }; - - let my_ip_key_bytes = unsafe { plain::as_bytes(&my_ip_key) }; - my_ip_key_bytes.to_vec().into_boxed_slice() -} - -pub fn convert_ipv6_into_bpf_map_key_bytes(ip: Ipv6Addr, prefixlen: u32) -> Box<[u8]> { - let ip_bytes = ip.octets(); - - let my_ip_key: bpf::types::lpm_key_v6 = bpf::types::lpm_key_v6 { - prefixlen, - addr: ip_bytes, - }; - - let my_ip_key_bytes = unsafe { 
plain::as_bytes(&my_ip_key) }; - my_ip_key_bytes.to_vec().into_boxed_slice() -} - -pub fn bpf_detach_from_xdp(ifindex: i32) -> Result<(), Box> { - // Create a dummy XDP instance for detaching - // We need to query first to get the existing program ID - let dummy_fd = unsafe { libc::open("/dev/null\0".as_ptr() as *const libc::c_char, libc::O_RDONLY) }; - if dummy_fd < 0 { - return Err("Failed to create dummy file descriptor".into()); - } - - let xdp = Xdp::new(unsafe { std::os::fd::BorrowedFd::borrow_raw(dummy_fd) }); - - // Try to detach using different modes - let modes = [XdpFlags::HW_MODE, XdpFlags::DRV_MODE, XdpFlags::SKB_MODE]; - - for mode in modes { - if let Ok(()) = xdp.detach(ifindex, mode) { - log::info!("XDP program detached from interface"); - unsafe { libc::close(dummy_fd); } - return Ok(()); - } - } - - unsafe { libc::close(dummy_fd); } - Err("Failed to detach XDP program from interface".into()) -} +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::os::fd::AsFd; +use std::fs; + +use crate::bpf::{self, FilterSkel}; +use libbpf_rs::{Xdp, XdpFlags}; +use nix::libc; + +/// XDP attachment mode +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum XdpMode { + Hardware, + Driver, + DriverReplace, + Skb, + SkbReplace, + SkbIpv6Enabled, + SkbUpdateIfNoexist, +} + +impl XdpMode { + pub fn as_str(&self) -> &'static str { + match self { + XdpMode::Hardware => "hardware", + XdpMode::Driver => "driver", + XdpMode::DriverReplace => "driver_replace", + XdpMode::Skb => "skb", + XdpMode::SkbReplace => "skb_replace", + XdpMode::SkbIpv6Enabled => "skb_ipv6_enabled", + XdpMode::SkbUpdateIfNoexist => "skb_update_if_noexist", + } + } +} + +fn is_ipv6_disabled(iface: Option<&str>) -> bool { + // Check if IPv6 is disabled for a specific interface or system-wide + if let Some(iface_name) = iface { + if let Ok(content) = fs::read_to_string(format!("/proc/sys/net/ipv6/conf/{}/disable_ipv6", iface_name)) { + return content.trim() == "1"; + } + } + // Fall back to system-wide check + if let Ok(content) = fs::read_to_string("/proc/sys/net/ipv6/conf/all/disable_ipv6") { + return content.trim() == "1"; + } + false +} + +fn try_enable_ipv6_for_interface(iface: &str) -> Result<(), Box> { + // Try to enable IPv6 only for the specific interface (not system-wide) + // This allows IPv4-only operation elsewhere while enabling XDP on this interface + let disable_path = format!("/proc/sys/net/ipv6/conf/{}/disable_ipv6", iface); + + if is_ipv6_disabled(Some(iface)) { + log::debug!("IPv6 is disabled for interface {}, attempting to enable it for XDP attachment", iface); + std::fs::write(&disable_path, "0")?; + log::info!("Enabled IPv6 for interface {} (required for XDP, IPv4-only elsewhere)", iface); + Ok(()) + } else { + Ok(()) + } +} + +pub fn bpf_attach_to_xdp( + skel: &mut FilterSkel<'_>, + ifindex: i32, + iface_name: Option<&str>, + ip_version: &str, +) -> Result> { + // Try hardware mode first, fall back to driver mode if not supported + let xdp = Xdp::new(skel.progs.arxignis_xdp_filter.as_fd().into()); + + // Try hardware offload mode first + if let Ok(()) = xdp.attach(ifindex, XdpFlags::HW_MODE) { + return Ok(XdpMode::Hardware); + } + + // Fall back to driver mode if hardware mode fails + match xdp.attach(ifindex, XdpFlags::DRV_MODE) { + Ok(()) => { + return Ok(XdpMode::Driver); + } + Err(e) => { + // Check if error is EEXIST (error 17) - XDP program already attached + let error_msg = e.to_string(); + if error_msg.contains("17") || error_msg.contains("File exists") { + log::debug!("Driver mode failed: XDP program 
already attached, trying to replace with REPLACE flag"); + // Try to replace existing XDP program + match xdp.attach(ifindex, XdpFlags::DRV_MODE | XdpFlags::REPLACE) { + Ok(()) => { + return Ok(XdpMode::DriverReplace); + } + Err(e2) => { + log::debug!("Replace in driver mode failed: {}, trying generic SKB mode", e2); + } + } + } else { + log::debug!("Driver mode failed, trying generic SKB mode: {}", e); + } + } + } + + // Try SKB mode (should work on all interfaces, including IPv4-only) + match xdp.attach(ifindex, XdpFlags::SKB_MODE) { + Ok(()) => { + return Ok(XdpMode::Skb); + } + Err(e) => { + // Check if error is EEXIST (error 17) first + let error_msg = e.to_string(); + if error_msg.contains("17") || error_msg.contains("File exists") { + log::debug!("SKB mode failed: XDP program already attached, trying to replace"); + // Try to replace existing XDP program in SKB mode + match xdp.attach(ifindex, XdpFlags::SKB_MODE | XdpFlags::REPLACE) { + Ok(()) => { + return Ok(XdpMode::SkbReplace); + } + Err(e2) => { + log::debug!("Replace in SKB mode failed: {}, continuing with other fallbacks", e2); + } + } + } + // If SKB mode fails with EAFNOSUPPORT (error 97), it's likely due to IPv6 being disabled + if error_msg.contains("97") || error_msg.contains("Address family not supported") { + log::debug!("SKB mode failed with EAFNOSUPPORT, IPv6 might be disabled"); + + // Note: XDP requires IPv6 to be enabled at the kernel level for attachment, + // even when processing only IPv4 packets. This is a kernel limitation. + // For IPv4-only mode, we can enable IPv6 just for this interface (not system-wide) + // which allows XDP to attach while still operating in IPv4-only mode. + if ip_version == "ipv4" { + log::debug!("IPv4-only mode: Attempting to enable IPv6 on interface for XDP attachment (kernel requirement)"); + } + + // Try to enable IPv6 only for this specific interface (not system-wide) + // This allows IPv4-only operation elsewhere while enabling XDP on this interface + if let Some(iface) = iface_name { + if try_enable_ipv6_for_interface(iface).is_ok() { + log::debug!("Retrying XDP attachment after enabling IPv6 for interface {}", iface); + + // Retry SKB mode after enabling IPv6 for the interface + match xdp.attach(ifindex, XdpFlags::SKB_MODE) { + Ok(()) => { + return Ok(XdpMode::SkbIpv6Enabled); + } + Err(e2) => { + log::debug!("SKB mode still failed after enabling IPv6 for interface: {}", e2); + } + } + } else { + log::debug!("Failed to enable IPv6 for interface {} or no permission", iface); + } + } else { + log::debug!("Interface name not provided, cannot enable IPv6 per-interface"); + } + + // Try with UPDATE_IF_NOEXIST flag as last resort + match xdp.attach(ifindex, XdpFlags::SKB_MODE | XdpFlags::UPDATE_IF_NOEXIST) { + Ok(()) => { + return Ok(XdpMode::SkbUpdateIfNoexist); + } + Err(e2) => { + log::debug!("SKB mode with UPDATE_IF_NOEXIST also failed: {}", e2); + } + } + } + + Err(Box::new(e)) + } + } +} + +pub fn ipv4_to_u32_be(ip: Ipv4Addr) -> u32 { + u32::from_be_bytes(ip.octets()) +} + +pub fn convert_ip_into_bpf_map_key_bytes(ip: Ipv4Addr, prefixlen: u32) -> Box<[u8]> { + let ip_u32: u32 = ip.into(); + let ip_be = ip_u32.to_be(); + + let my_ip_key: bpf::types::lpm_key = bpf::types::lpm_key { + prefixlen, + addr: ip_be, + }; + + let my_ip_key_bytes = unsafe { plain::as_bytes(&my_ip_key) }; + my_ip_key_bytes.to_vec().into_boxed_slice() +} + +pub fn convert_ipv6_into_bpf_map_key_bytes(ip: Ipv6Addr, prefixlen: u32) -> Box<[u8]> { + let ip_bytes = ip.octets(); + + let my_ip_key: 
bpf::types::lpm_key_v6 = bpf::types::lpm_key_v6 {
+ prefixlen,
+ addr: ip_bytes,
+ };
+
+ let my_ip_key_bytes = unsafe { plain::as_bytes(&my_ip_key) };
+ my_ip_key_bytes.to_vec().into_boxed_slice()
+}
+
+pub fn bpf_detach_from_xdp(ifindex: i32) -> Result<(), Box<dyn std::error::Error>> {
+ // Create a dummy XDP instance for detaching
+ // We need to query first to get the existing program ID
+ let dummy_fd = unsafe { libc::open("/dev/null\0".as_ptr() as *const libc::c_char, libc::O_RDONLY) };
+ if dummy_fd < 0 {
+ return Err("Failed to create dummy file descriptor".into());
+ }
+
+ let xdp = Xdp::new(unsafe { std::os::fd::BorrowedFd::borrow_raw(dummy_fd) });
+
+ // Try to detach using different modes
+ let modes = [XdpFlags::HW_MODE, XdpFlags::DRV_MODE, XdpFlags::SKB_MODE];
+
+ for mode in modes {
+ if let Ok(()) = xdp.detach(ifindex, mode) {
+ log::info!("XDP program detached from interface");
+ unsafe { libc::close(dummy_fd); }
+ return Ok(());
+ }
+ }
+
+ unsafe { libc::close(dummy_fd); }
+ Err("Failed to detach XDP program from interface".into())
+}
diff --git a/src/utils/bpf_utils_noop.rs b/src/utils/bpf_utils_noop.rs
index ba4d863..c87956f 100644
--- a/src/utils/bpf_utils_noop.rs
+++ b/src/utils/bpf_utils_noop.rs
@@ -1,16 +1,16 @@
-use std::error::Error;
-
-use crate::bpf::FilterSkel;
-
-pub fn bpf_attach_to_xdp(
- _skel: &mut FilterSkel<'_>,
- _ifindex: i32,
- _iface_name: Option<&str>,
- _ip_version: &str,
-) -> Result<(), Box<dyn Error>> {
- Err("BPF support disabled at build time".into())
-}
-
-pub fn bpf_detach_from_xdp(_ifindex: i32) -> Result<(), Box<dyn Error>> {
- Ok(())
-}
+use std::error::Error;
+
+use crate::bpf::FilterSkel;
+
+pub fn bpf_attach_to_xdp(
+ _skel: &mut FilterSkel<'_>,
+ _ifindex: i32,
+ _iface_name: Option<&str>,
+ _ip_version: &str,
+) -> Result<(), Box<dyn Error>> {
+ Err("BPF support disabled at build time".into())
+}
+
+pub fn bpf_detach_from_xdp(_ifindex: i32) -> Result<(), Box<dyn Error>> {
+ Ok(())
+}
diff --git a/src/utils/discovery.rs b/src/utils/discovery.rs
index 0705c71..5654d7b 100644
--- a/src/utils/discovery.rs
+++ b/src/utils/discovery.rs
@@ -1,48 +1,48 @@
-use crate::utils::filewatch;
-use crate::utils::structs::Configuration;
-use crate::http_proxy::webserver;
-use async_trait::async_trait;
-use futures::channel::mpsc::Sender;
-use std::sync::Arc;
-
-pub struct APIUpstreamProvider {
- pub config_api_enabled: bool,
- pub address: String,
- pub masterkey: String,
- pub tls_address: Option<String>,
- pub tls_certificate: Option<String>,
- pub tls_key_file: Option<String>,
- pub file_server_address: Option<String>,
- pub file_server_folder: Option<String>,
-}
-
-#[async_trait]
-impl Discovery for APIUpstreamProvider {
- async fn start(&self, toreturn: Sender<Configuration>) {
- webserver::run_server(self, toreturn).await;
- }
-}
-
-pub struct FromFileProvider {
- pub path: String,
-}
-
-pub struct ConsulProvider {
- pub config: Arc<Configuration>,
-}
-
-pub struct KubernetesProvider {
- pub config: Arc<Configuration>,
-}
-
-#[async_trait]
-pub trait Discovery {
- async fn start(&self, tx: Sender<Configuration>);
-}
-
-#[async_trait]
-impl Discovery for FromFileProvider {
- async fn start(&self, tx: Sender<Configuration>) {
- tokio::spawn(filewatch::start(self.path.clone(), tx.clone()));
- }
-}
+use crate::utils::filewatch;
+use crate::utils::structs::Configuration;
+use crate::http_proxy::webserver;
+use async_trait::async_trait;
+use futures::channel::mpsc::Sender;
+use std::sync::Arc;
+
+pub struct APIUpstreamProvider {
+ pub config_api_enabled: bool,
+ pub address: String,
+ pub masterkey: String,
+ pub tls_address: Option<String>,
+ pub tls_certificate: Option<String>,
+ pub tls_key_file: Option<String>,
+ pub file_server_address: Option<String>,
+
pub file_server_folder: Option<String>,
+}
+
+#[async_trait]
+impl Discovery for APIUpstreamProvider {
+ async fn start(&self, toreturn: Sender<Configuration>) {
+ webserver::run_server(self, toreturn).await;
+ }
+}
+
+pub struct FromFileProvider {
+ pub path: String,
+}
+
+pub struct ConsulProvider {
+ pub config: Arc<Configuration>,
+}
+
+pub struct KubernetesProvider {
+ pub config: Arc<Configuration>,
+}
+
+#[async_trait]
+pub trait Discovery {
+ async fn start(&self, tx: Sender<Configuration>);
+}
+
+#[async_trait]
+impl Discovery for FromFileProvider {
+ async fn start(&self, tx: Sender<Configuration>) {
+ tokio::spawn(filewatch::start(self.path.clone(), tx.clone()));
+ }
+}
diff --git a/src/utils/filewatch.rs b/src/utils/filewatch.rs
index 3d4d456..e8946b3 100644
--- a/src/utils/filewatch.rs
+++ b/src/utils/filewatch.rs
@@ -1,59 +1,59 @@
-use crate::utils::parceyaml::load_configuration;
-use crate::utils::structs::Configuration;
-use futures::channel::mpsc::Sender;
-use futures::SinkExt;
-use log::error;
-use notify::event::ModifyKind;
-use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
-use tokio::time::sleep;
-use std::path::Path;
-use std::time::{Duration, Instant};
-use tokio::task;
-
-pub async fn start(fp: String, mut toreturn: Sender<Configuration>) {
- sleep(Duration::from_millis(50)).await; // For having nice logs :-)
- let file_path = fp.as_str();
- let parent_dir = Path::new(file_path).parent().unwrap();
- let (local_tx, mut local_rx) = tokio::sync::mpsc::channel::<Result<Event, notify::Error>>(1);
-
- let _watcher_handle = task::spawn_blocking({
- let parent_dir = parent_dir.to_path_buf(); // Move directory path into the closure
- move || {
- let mut watcher = RecommendedWatcher::new(
- move |res| {
- let _ = local_tx.blocking_send(res);
- },
- Config::default(),
- )
- .unwrap();
- watcher.watch(&parent_dir, RecursiveMode::Recursive).unwrap();
- let (_rtx, mut rrx) = tokio::sync::mpsc::channel::<()>(1);
- let _ = rrx.blocking_recv();
- }
- });
- let mut start = Instant::now();
-
- while let Some(event) = local_rx.recv().await {
- match event {
- Ok(e) => match e.kind {
- EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(..) | EventKind::Remove(..)
=> { - if e.paths[0].to_str().unwrap().ends_with("yaml") { - if start.elapsed() > Duration::from_secs(2) { - start = Instant::now(); - // info!("Config File changed :=> {:?}", e); - let snd = load_configuration(file_path, "filepath").await; - match snd { - Some(snd) => { - toreturn.send(snd).await.unwrap(); - } - None => {} - } - } - } - } - _ => (), - }, - Err(e) => error!("Watch error: {:?}", e), - } - } -} +use crate::utils::parceyaml::load_configuration; +use crate::utils::structs::Configuration; +use futures::channel::mpsc::Sender; +use futures::SinkExt; +use log::error; +use notify::event::ModifyKind; +use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use tokio::time::sleep; +use std::path::Path; +use std::time::{Duration, Instant}; +use tokio::task; + +pub async fn start(fp: String, mut toreturn: Sender) { + sleep(Duration::from_millis(50)).await; // For having nice logs :-) + let file_path = fp.as_str(); + let parent_dir = Path::new(file_path).parent().unwrap(); + let (local_tx, mut local_rx) = tokio::sync::mpsc::channel::>(1); + + let _watcher_handle = task::spawn_blocking({ + let parent_dir = parent_dir.to_path_buf(); // Move directory path into the closure + move || { + let mut watcher = RecommendedWatcher::new( + move |res| { + let _ = local_tx.blocking_send(res); + }, + Config::default(), + ) + .unwrap(); + watcher.watch(&parent_dir, RecursiveMode::Recursive).unwrap(); + let (_rtx, mut rrx) = tokio::sync::mpsc::channel::(1); + let _ = rrx.blocking_recv(); + } + }); + let mut start = Instant::now(); + + while let Some(event) = local_rx.recv().await { + match event { + Ok(e) => match e.kind { + EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(..) | EventKind::Remove(..) => { + if e.paths[0].to_str().unwrap().ends_with("yaml") { + if start.elapsed() > Duration::from_secs(2) { + start = Instant::now(); + // info!("Config File changed :=> {:?}", e); + let snd = load_configuration(file_path, "filepath").await; + match snd { + Some(snd) => { + toreturn.send(snd).await.unwrap(); + } + None => {} + } + } + } + } + _ => (), + }, + Err(e) => error!("Watch error: {:?}", e), + } + } +} diff --git a/src/utils/healthcheck.rs b/src/utils/healthcheck.rs index 8419569..e53af51 100644 --- a/src/utils/healthcheck.rs +++ b/src/utils/healthcheck.rs @@ -1,161 +1,161 @@ -use crate::utils::structs::{InnerMap, UpstreamsDashMap, UpstreamsIdMap}; -use crate::utils::tools::*; -use dashmap::DashMap; -use log::{error, warn}; -use reqwest::{Client, Version}; -use std::sync::atomic::AtomicUsize; -use std::sync::Arc; -use std::time::Duration; -use tokio::time::interval; -use tonic::transport::Endpoint; - -pub async fn hc2(upslist: Arc, fullist: Arc, idlist: Arc, params: (&str, u64)) { - let mut period = interval(Duration::from_secs(params.1)); - let client = Client::builder().timeout(Duration::from_secs(params.1)).danger_accept_invalid_certs(true).build().unwrap(); - loop { - tokio::select! 
{ - _ = period.tick() => { - populate_upstreams(&upslist, &fullist, &idlist, params, &client).await; - } - } - } -} - -pub async fn populate_upstreams(upslist: &Arc, fullist: &Arc, idlist: &Arc, params: (&str, u64), client: &Client) { - let totest = build_upstreams(fullist, params.0, client).await; - if !compare_dashmaps(&totest, upslist) { - clone_dashmap_into(&totest, upslist); - clone_idmap_into(&totest, idlist); - } -} - -pub async fn initiate_upstreams(fullist: UpstreamsDashMap) -> UpstreamsDashMap { - let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); - build_upstreams(&fullist, "HEAD", &client).await -} - -async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Client) -> UpstreamsDashMap { - let totest: UpstreamsDashMap = DashMap::new(); - let fclone = clone_dashmap(fullist); - for val in fclone.iter() { - let host = val.key(); - let inner = DashMap::new(); - - for path_entry in val.value().iter() { - let path = path_entry.key(); - let mut innervec = Vec::new(); - - for (_, upstream) in path_entry.value().0.iter().enumerate() { - let tls = detect_tls(upstream.address.as_str(), &upstream.port, &client).await; - let is_h2 = matches!(tls.1, Some(Version::HTTP_2)); - - let link = if tls.0 { - format!("https://{}:{}{}", upstream.address, upstream.port, path) - } else { - format!("http://{}:{}{}", upstream.address, upstream.port, path) - }; - - let mut scheme = InnerMap { - address: upstream.address.clone(), - port: upstream.port, - ssl_enabled: tls.0, - http2_enabled: is_h2, - https_proxy_enabled: upstream.https_proxy_enabled, - rate_limit: upstream.rate_limit, - healthcheck: upstream.healthcheck, - disable_access_log: upstream.disable_access_log, - }; - - if scheme.healthcheck.unwrap_or(true) { - let resp = http_request(&link, method, "", &client).await; - if resp.0 { - if resp.1 { - scheme.http2_enabled = is_h2; // could be adjusted further - } - innervec.push(scheme); - } else { - warn!("Dead Upstream : {}", link); - } - } else { - innervec.push(scheme); - } - - // let resp = http_request(&link, method, "", &client).await; - // if resp.0 { - // if resp.1 { - // scheme.is_http2 = is_h2; // could be adjusted further - // } - // innervec.push(scheme); - // } else { - // warn!("Dead Upstream : {}", link); - // } - } - inner.insert(path.clone(), (innervec, AtomicUsize::new(0))); - } - totest.insert(host.clone(), inner); - } - totest -} - -async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) { - if !["POST", "GET", "HEAD"].contains(&method) { - error!("Method {} not supported. 
Only GET|POST|HEAD are supported ", method); - return (false, false); - } - async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option { - match method { - "POST" => client.post(url).body(payload.to_owned()).send().await.ok(), - "GET" => client.get(url).send().await.ok(), - "HEAD" => client.head(url).send().await.ok(), - _ => None, - } - } - - match send_request(&client, method, url, payload).await { - Some(response) => { - let status = response.status().as_u16(); - ((99..499).contains(&status), false) - } - None => (ping_grpc(&url).await, true), - } -} - -pub async fn ping_grpc(addr: &str) -> bool { - let endpoint_result = Endpoint::from_shared(addr.to_owned()); - - if let Ok(endpoint) = endpoint_result { - let endpoint = endpoint.timeout(Duration::from_secs(2)); - - match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await { - Ok(Ok(_channel)) => true, - _ => false, - } - } else { - false - } -} - -async fn detect_tls(ip: &str, port: &u16, client: &Client) -> (bool, Option) { - let https_url = format!("https://{}:{}", ip, port); - match client.get(&https_url).send().await { - Ok(response) => { - // println!("{} => {:?} (HTTPS)", https_url, response.version()); - return (true, Some(response.version())); - } - _ => {} - } - let http_url = format!("http://{}:{}", ip, port); - match client.get(&http_url).send().await { - Ok(response) => { - // println!("{} => {:?} (HTTP)", http_url, response.version()); - (false, Some(response.version())) - } - Err(_) => { - if ping_grpc(&http_url).await { - (false, Some(Version::HTTP_2)) - } else { - (false, None) - } - } - } -} +use crate::utils::structs::{InnerMap, UpstreamsDashMap, UpstreamsIdMap}; +use crate::utils::tools::*; +use dashmap::DashMap; +use log::{error, warn}; +use reqwest::{Client, Version}; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::interval; +use tonic::transport::Endpoint; + +pub async fn hc2(upslist: Arc, fullist: Arc, idlist: Arc, params: (&str, u64)) { + let mut period = interval(Duration::from_secs(params.1)); + let client = Client::builder().timeout(Duration::from_secs(params.1)).danger_accept_invalid_certs(true).build().unwrap(); + loop { + tokio::select! 
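+        // Each tick re-probes every configured upstream; populate_upstreams()
+        // below only swaps the live maps when the probe results actually change.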
{ + _ = period.tick() => { + populate_upstreams(&upslist, &fullist, &idlist, params, &client).await; + } + } + } +} + +pub async fn populate_upstreams(upslist: &Arc, fullist: &Arc, idlist: &Arc, params: (&str, u64), client: &Client) { + let totest = build_upstreams(fullist, params.0, client).await; + if !compare_dashmaps(&totest, upslist) { + clone_dashmap_into(&totest, upslist); + clone_idmap_into(&totest, idlist); + } +} + +pub async fn initiate_upstreams(fullist: UpstreamsDashMap) -> UpstreamsDashMap { + let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); + build_upstreams(&fullist, "HEAD", &client).await +} + +async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Client) -> UpstreamsDashMap { + let totest: UpstreamsDashMap = DashMap::new(); + let fclone = clone_dashmap(fullist); + for val in fclone.iter() { + let host = val.key(); + let inner = DashMap::new(); + + for path_entry in val.value().iter() { + let path = path_entry.key(); + let mut innervec = Vec::new(); + + for (_, upstream) in path_entry.value().0.iter().enumerate() { + let tls = detect_tls(upstream.address.as_str(), &upstream.port, &client).await; + let is_h2 = matches!(tls.1, Some(Version::HTTP_2)); + + let link = if tls.0 { + format!("https://{}:{}{}", upstream.address, upstream.port, path) + } else { + format!("http://{}:{}{}", upstream.address, upstream.port, path) + }; + + let mut scheme = InnerMap { + address: upstream.address.clone(), + port: upstream.port, + ssl_enabled: tls.0, + http2_enabled: is_h2, + https_proxy_enabled: upstream.https_proxy_enabled, + rate_limit: upstream.rate_limit, + healthcheck: upstream.healthcheck, + disable_access_log: upstream.disable_access_log, + }; + + if scheme.healthcheck.unwrap_or(true) { + let resp = http_request(&link, method, "", &client).await; + if resp.0 { + if resp.1 { + scheme.http2_enabled = is_h2; // could be adjusted further + } + innervec.push(scheme); + } else { + warn!("Dead Upstream : {}", link); + } + } else { + innervec.push(scheme); + } + + // let resp = http_request(&link, method, "", &client).await; + // if resp.0 { + // if resp.1 { + // scheme.is_http2 = is_h2; // could be adjusted further + // } + // innervec.push(scheme); + // } else { + // warn!("Dead Upstream : {}", link); + // } + } + inner.insert(path.clone(), (innervec, AtomicUsize::new(0))); + } + totest.insert(host.clone(), inner); + } + totest +} + +async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) { + if !["POST", "GET", "HEAD"].contains(&method) { + error!("Method {} not supported. 
Only GET|POST|HEAD are supported ", method); + return (false, false); + } + async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option { + match method { + "POST" => client.post(url).body(payload.to_owned()).send().await.ok(), + "GET" => client.get(url).send().await.ok(), + "HEAD" => client.head(url).send().await.ok(), + _ => None, + } + } + + match send_request(&client, method, url, payload).await { + Some(response) => { + let status = response.status().as_u16(); + ((99..499).contains(&status), false) + } + None => (ping_grpc(&url).await, true), + } +} + +pub async fn ping_grpc(addr: &str) -> bool { + let endpoint_result = Endpoint::from_shared(addr.to_owned()); + + if let Ok(endpoint) = endpoint_result { + let endpoint = endpoint.timeout(Duration::from_secs(2)); + + match tokio::time::timeout(Duration::from_secs(3), endpoint.connect()).await { + Ok(Ok(_channel)) => true, + _ => false, + } + } else { + false + } +} + +async fn detect_tls(ip: &str, port: &u16, client: &Client) -> (bool, Option) { + let https_url = format!("https://{}:{}", ip, port); + match client.get(&https_url).send().await { + Ok(response) => { + // println!("{} => {:?} (HTTPS)", https_url, response.version()); + return (true, Some(response.version())); + } + _ => {} + } + let http_url = format!("http://{}:{}", ip, port); + match client.get(&http_url).send().await { + Ok(response) => { + // println!("{} => {:?} (HTTP)", http_url, response.version()); + (false, Some(response.version())) + } + Err(_) => { + if ping_grpc(&http_url).await { + (false, Some(Version::HTTP_2)) + } else { + (false, None) + } + } + } +} diff --git a/src/utils/http_utils.rs b/src/utils/http_utils.rs index 3f26d3e..2690218 100644 --- a/src/utils/http_utils.rs +++ b/src/utils/http_utils.rs @@ -1,86 +1,86 @@ -use std::net::IpAddr; - -/// Parse IP or CIDR notation into network and prefix length -pub fn parse_ip_or_cidr(entry: &str) -> Option<(IpAddr, u8)> { - let s = entry.trim(); - if s.is_empty() { - return None; - } - - if s.contains('/') { - let mut parts = s.split('/'); - let ip_str = parts.next()?.trim(); - let prefix_str = parts.next()?.trim(); - if parts.next().is_some() { - return None; // malformed - } - - let ip = ip_str.parse::().ok()?; - let prefix: u8 = prefix_str.parse().ok()?; - - // Validate prefix length - match ip { - IpAddr::V4(_) => { - if prefix > 32 { - return None; - } - } - IpAddr::V6(_) => { - if prefix > 128 { - return None; - } - } - } - - Some((ip, prefix)) - } else { - // Single IP address - let ip = s.parse::().ok()?; - let prefix = match ip { - IpAddr::V4(_) => 32, - IpAddr::V6(_) => 128, - }; - Some((ip, prefix)) - } -} - -/// Check if an IP address is within a CIDR range -pub fn is_ip_in_cidr(ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool { - match (ip, network) { - (IpAddr::V4(ip), IpAddr::V4(network)) => { - let ip_u32 = u32::from(ip); - let net_u32 = u32::from(network); - let mask = if prefix_len == 0 { - 0 - } else { - u32::MAX.checked_shl((32 - prefix_len) as u32).unwrap_or(0) - }; - (ip_u32 & mask) == (net_u32 & mask) - } - (IpAddr::V6(ip), IpAddr::V6(network)) => { - let ip_bytes = ip.octets(); - let net_bytes = network.octets(); - let prefix_bytes = prefix_len / 8; - let remaining_bits = prefix_len % 8; - - // Check full bytes - for i in 0..prefix_bytes as usize { - if ip_bytes[i] != net_bytes[i] { - return false; - } - } - - // Check remaining bits - if remaining_bits > 0 && prefix_bytes < 16 { - let mask = 0xFF << (8 - remaining_bits); - if (ip_bytes[prefix_bytes as usize] 
& mask) != (net_bytes[prefix_bytes as usize] & mask) { - return false; - } - } - - true - } - _ => false, // Different IP versions - } -} - +use std::net::IpAddr; + +/// Parse IP or CIDR notation into network and prefix length +pub fn parse_ip_or_cidr(entry: &str) -> Option<(IpAddr, u8)> { + let s = entry.trim(); + if s.is_empty() { + return None; + } + + if s.contains('/') { + let mut parts = s.split('/'); + let ip_str = parts.next()?.trim(); + let prefix_str = parts.next()?.trim(); + if parts.next().is_some() { + return None; // malformed + } + + let ip = ip_str.parse::().ok()?; + let prefix: u8 = prefix_str.parse().ok()?; + + // Validate prefix length + match ip { + IpAddr::V4(_) => { + if prefix > 32 { + return None; + } + } + IpAddr::V6(_) => { + if prefix > 128 { + return None; + } + } + } + + Some((ip, prefix)) + } else { + // Single IP address + let ip = s.parse::().ok()?; + let prefix = match ip { + IpAddr::V4(_) => 32, + IpAddr::V6(_) => 128, + }; + Some((ip, prefix)) + } +} + +/// Check if an IP address is within a CIDR range +pub fn is_ip_in_cidr(ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool { + match (ip, network) { + (IpAddr::V4(ip), IpAddr::V4(network)) => { + let ip_u32 = u32::from(ip); + let net_u32 = u32::from(network); + let mask = if prefix_len == 0 { + 0 + } else { + u32::MAX.checked_shl((32 - prefix_len) as u32).unwrap_or(0) + }; + (ip_u32 & mask) == (net_u32 & mask) + } + (IpAddr::V6(ip), IpAddr::V6(network)) => { + let ip_bytes = ip.octets(); + let net_bytes = network.octets(); + let prefix_bytes = prefix_len / 8; + let remaining_bits = prefix_len % 8; + + // Check full bytes + for i in 0..prefix_bytes as usize { + if ip_bytes[i] != net_bytes[i] { + return false; + } + } + + // Check remaining bits + if remaining_bits > 0 && prefix_bytes < 16 { + let mask = 0xFF << (8 - remaining_bits); + if (ip_bytes[prefix_bytes as usize] & mask) != (net_bytes[prefix_bytes as usize] & mask) { + return false; + } + } + + true + } + _ => false, // Different IP versions + } +} + diff --git a/src/utils/metrics.rs b/src/utils/metrics.rs index 37f8d6c..42c194e 100644 --- a/src/utils/metrics.rs +++ b/src/utils/metrics.rs @@ -1,83 +1,83 @@ -use pingora_http::Version; -use prometheus::{register_histogram, register_int_counter, register_int_counter_vec, Histogram, IntCounter, IntCounterVec}; -use std::time::Duration; -use once_cell::sync::Lazy; - -pub struct MetricTypes { - pub method: String, - pub code: String, - pub latency: Duration, - pub version: Version, -} - -pub static REQUEST_COUNT: Lazy = Lazy::new(|| { - register_int_counter!( - "synapse_requests_total", - "Total number of requests handled by Gen0Sec Synapse" - ).unwrap() -}); - -pub static RESPONSE_CODES: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "synapse_responses_total", - "Responses grouped by status code by Gen0Sec Synapse", - &["status"] - ).unwrap() -}); - -pub static REQUEST_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "synapse_request_latency_seconds", - "Request latency in seconds by Gen0Sec Synapse", - vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0] - ).unwrap() -}); - -pub static RESPONSE_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "synapse_response_latency_seconds", - "Response latency in seconds by Gen0Sec Synapse", - vec![0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0] - ).unwrap() -}); - -pub static REQUESTS_BY_METHOD: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "synapse_requests_by_method_total", - "Number of requests by HTTP method by Gen0Sec Synapse", 
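// One time series per method label, e.g. synapse_requests_by_method_total{method="GET"}.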
- &["method"] - ).unwrap() -}); - -pub static REQUESTS_BY_VERSION: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "synapse_requests_by_version_total", - "Number of requests by HTTP versions by Gen0Sec Synapse", - &["version"] - ).unwrap() -}); - -pub static ERROR_COUNT: Lazy = Lazy::new(|| { - register_int_counter!( - "synapse_errors_total", - "Total number of errors by Gen0Sec Synapse" - ).unwrap() -}); - -pub fn calc_metrics(metric_types: &MetricTypes) { - REQUEST_COUNT.inc(); - let timer = REQUEST_LATENCY.start_timer(); - timer.observe_duration(); - - let version_str = match &metric_types.version { - &Version::HTTP_11 => "HTTP/1.1", - &Version::HTTP_2 => "HTTP/2.0", - &Version::HTTP_3 => "HTTP/3.0", - &Version::HTTP_10 => "HTTP/1.0", - _ => "Unknown", - }; - REQUESTS_BY_VERSION.with_label_values(&[&version_str]).inc(); - RESPONSE_CODES.with_label_values(&[&metric_types.code.to_string()]).inc(); - REQUESTS_BY_METHOD.with_label_values(&[&metric_types.method]).inc(); - RESPONSE_LATENCY.observe(metric_types.latency.as_secs_f64()); -} +use pingora_http::Version; +use prometheus::{register_histogram, register_int_counter, register_int_counter_vec, Histogram, IntCounter, IntCounterVec}; +use std::time::Duration; +use once_cell::sync::Lazy; + +pub struct MetricTypes { + pub method: String, + pub code: String, + pub latency: Duration, + pub version: Version, +} + +pub static REQUEST_COUNT: Lazy = Lazy::new(|| { + register_int_counter!( + "synapse_requests_total", + "Total number of requests handled by Gen0Sec Synapse" + ).unwrap() +}); + +pub static RESPONSE_CODES: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "synapse_responses_total", + "Responses grouped by status code by Gen0Sec Synapse", + &["status"] + ).unwrap() +}); + +pub static REQUEST_LATENCY: Lazy = Lazy::new(|| { + register_histogram!( + "synapse_request_latency_seconds", + "Request latency in seconds by Gen0Sec Synapse", + vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0] + ).unwrap() +}); + +pub static RESPONSE_LATENCY: Lazy = Lazy::new(|| { + register_histogram!( + "synapse_response_latency_seconds", + "Response latency in seconds by Gen0Sec Synapse", + vec![0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0] + ).unwrap() +}); + +pub static REQUESTS_BY_METHOD: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "synapse_requests_by_method_total", + "Number of requests by HTTP method by Gen0Sec Synapse", + &["method"] + ).unwrap() +}); + +pub static REQUESTS_BY_VERSION: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "synapse_requests_by_version_total", + "Number of requests by HTTP versions by Gen0Sec Synapse", + &["version"] + ).unwrap() +}); + +pub static ERROR_COUNT: Lazy = Lazy::new(|| { + register_int_counter!( + "synapse_errors_total", + "Total number of errors by Gen0Sec Synapse" + ).unwrap() +}); + +pub fn calc_metrics(metric_types: &MetricTypes) { + REQUEST_COUNT.inc(); + let timer = REQUEST_LATENCY.start_timer(); + timer.observe_duration(); + + let version_str = match &metric_types.version { + &Version::HTTP_11 => "HTTP/1.1", + &Version::HTTP_2 => "HTTP/2.0", + &Version::HTTP_3 => "HTTP/3.0", + &Version::HTTP_10 => "HTTP/1.0", + _ => "Unknown", + }; + REQUESTS_BY_VERSION.with_label_values(&[&version_str]).inc(); + RESPONSE_CODES.with_label_values(&[&metric_types.code.to_string()]).inc(); + REQUESTS_BY_METHOD.with_label_values(&[&metric_types.method]).inc(); + RESPONSE_LATENCY.observe(metric_types.latency.as_secs_f64()); +} diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index 
539c069..14536f6 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -1,316 +1,316 @@ -use crate::utils::healthcheck; -use crate::utils::state::{is_first_run, mark_not_first_run}; -use crate::utils::structs::*; -use crate::utils::tools::{clone_dashmap, clone_dashmap_into, print_upstreams}; -use dashmap::DashMap; -use log::{error, info, warn}; -use std::sync::atomic::AtomicUsize; -// use std::sync::mpsc::{channel, Receiver, Sender}; -use std::{env, fs}; -// use tokio::sync::oneshot::{Receiver, Sender}; - -pub async fn load_configuration(d: &str, kind: &str) -> Option { - let yaml_data = match kind { - "filepath" => match fs::read_to_string(d) { - Ok(data) => { - info!("Reading upstreams from {}", d); - data - } - Err(e) => { - error!("Reading: {}: {:?}", d, e); - warn!("Running with empty upstreams list, update it via API"); - return None; - } - }, - "content" => { - info!("Reading upstreams from API post body"); - d.to_string() - } - _ => { - error!("Mismatched parameter, only filepath|content is allowed"); - return None; - } - }; - - let parsed: Config = match serde_yaml::from_str(&yaml_data) { - Ok(cfg) => cfg, - Err(e) => { - error!("Failed to parse upstreams file: {}", e); - return None; - } - }; - - let mut toreturn = Configuration::default(); - - populate_headers_and_auth(&mut toreturn, &parsed).await; - toreturn.typecfg = parsed.provider.clone(); - - match parsed.provider.as_str() { - "file" => { - populate_file_upstreams(&mut toreturn, &parsed).await; - Some(toreturn) - } - "consul" => { - toreturn.consul = parsed.consul; - toreturn.consul.is_some().then_some(toreturn) - } - "kubernetes" => { - toreturn.kubernetes = parsed.kubernetes; - toreturn.kubernetes.is_some().then_some(toreturn) - } - _ => { - warn!("Unknown provider {}", parsed.provider); - None - } - } -} - -async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { - // Handle new config format with nested config: section - if let Some(global_config) = &parsed.config { - // Use values from config: section if present - config.extraparams.sticky_sessions = global_config.sticky_sessions; - config.extraparams.https_proxy_enabled = Some(global_config.https_proxy_enabled); - config.extraparams.rate_limit = global_config.global_rate_limit; - - if let Some(headers) = &global_config.global_headers { - let mut hl = Vec::new(); - for header in headers { - if let Some((key, val)) = header.split_once(':') { - hl.push((key.trim().to_string(), val.trim().to_string())); - } - } - - let global_headers = DashMap::new(); - global_headers.insert("/".to_string(), hl); - config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); - } - - if let Some(rate) = &global_config.global_rate_limit { - info!("Applied Global Rate Limit : {} request per second", rate); - } - - // Store healthcheck settings from upstreams config - config.healthcheck_interval = global_config.healthcheck_interval; - config.healthcheck_method = global_config.healthcheck_method.clone(); - } else { - // Fallback to old format (top-level fields) - if let Some(headers) = &parsed.headers { - let mut hl = Vec::new(); - for header in headers { - if let Some((key, val)) = header.split_once(':') { - hl.push((key.trim().to_string(), val.trim().to_string())); - } - } - - let global_headers = DashMap::new(); - global_headers.insert("/".to_string(), hl); - config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); - } - - config.extraparams.sticky_sessions = parsed.sticky_sessions; - config.extraparams.https_proxy_enabled = None; 
// Legacy format doesn't have this - config.extraparams.rate_limit = parsed.rate_limit; - - if let Some(rate) = &parsed.rate_limit { - info!("Applied Global Rate Limit : {} request per second", rate); - } - } - - if let Some(auth) = &parsed.authorization { - let name = auth.get("type").unwrap_or(&"".to_string()).to_string(); - let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string(); - config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]); - } else { - config.extraparams.authentication = DashMap::new(); - } -} - -async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { - // Handle arxignis_paths first - these are global paths that work across all hostnames - if let Some(arxignis_paths) = &parsed.arxignis_paths { - info!("Processing {} Arxignis paths", arxignis_paths.len()); - for (path, path_config) in arxignis_paths { - let mut server_list = Vec::new(); - for server in &path_config.servers { - if let Some((ip, port_str)) = server.split_once(':') { - if let Ok(port) = port_str.parse::() { - let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); - let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); - let http2_enabled = path_config.http2_enabled.unwrap_or(false); - let disable_access_log = path_config.disable_access_log.unwrap_or(false); - server_list.push(InnerMap { - address: ip.trim().to_string(), - port, - ssl_enabled, - http2_enabled, - https_proxy_enabled, - rate_limit: path_config.rate_limit, - healthcheck: path_config.healthcheck, - disable_access_log, - }); - } - } - } - config.arxignis_paths.insert(path.clone(), (server_list, AtomicUsize::new(0))); - info!("Arxignis path {} -> {} backend(s)", path, config.arxignis_paths.get(path).unwrap().0.len()); - } - } - - let imtdashmap = UpstreamsDashMap::new(); - if let Some(upstreams) = &parsed.upstreams { - for (hostname, host_config) in upstreams { - // Store certificate mapping if specified - if let Some(certificate_name) = &host_config.certificate { - config.certificates.insert(hostname.clone(), certificate_name.clone()); - info!("Upstream {} will use certificate: {}", hostname, certificate_name); - } - - let path_map = DashMap::new(); - let header_list = DashMap::new(); - for (path, path_config) in &host_config.paths { - if let Some(rate) = &path_config.rate_limit { - info!("Applied Rate Limit for {} : {} request per second", hostname, rate); - } - - let mut hl: Vec<(String, String)> = Vec::new(); - build_headers(&path_config.headers, config, &mut hl); - header_list.insert(path.clone(), hl); - - let mut server_list = Vec::new(); - for server in &path_config.servers { - if let Some((ip, port_str)) = server.split_once(':') { - if let Ok(port) = port_str.parse::() { - let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); - let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); // Default to SSL - let http2_enabled = path_config.http2_enabled.unwrap_or(false); // Default to HTTP/1.1 - let disable_access_log = path_config.disable_access_log.unwrap_or(false); - server_list.push(InnerMap { - address: ip.trim().to_string(), - port, - ssl_enabled, - http2_enabled, - https_proxy_enabled, - rate_limit: path_config.rate_limit, - healthcheck: path_config.healthcheck, - disable_access_log, - }); - } - } - } - path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); - } - config.headers.insert(hostname.clone(), header_list); - imtdashmap.insert(hostname.clone(), path_map); - } - - if is_first_run() { - 
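// First load: publish the parsed upstreams as-is; later reloads go through
// initiate_upstreams() below so dead backends are filtered out before the swap.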
clone_dashmap_into(&imtdashmap, &config.upstreams); - mark_not_first_run(); - } else { - let y = clone_dashmap(&imtdashmap); - let r = healthcheck::initiate_upstreams(y).await; - clone_dashmap_into(&r, &config.upstreams); - } - info!("Upstream Config:"); - print_upstreams(&config.upstreams); - } -} -pub fn parce_main_config(path: &str) -> AppConfig { - parce_main_config_with_log_level(path, None) -} - -pub fn parce_main_config_with_log_level(path: &str, log_level: Option<&str>) -> AppConfig { - let data = fs::read_to_string(path).unwrap(); - - if let Ok(new_config) = serde_yaml::from_str::(&data) { - log_builder(log_level); - return new_config.pingora.to_app_config(); - } - - let mut cfo: AppConfig = serde_yaml::from_str(&*data).expect("Failed to parse main config file"); - log_builder(log_level); - cfo.healthcheck_method = cfo.healthcheck_method.to_uppercase(); - if let Some((ip, port_str)) = cfo.config_address.split_once(':') { - if let Ok(port) = port_str.parse::() { - cfo.local_server = Option::from((ip.to_string(), port)); - } - } - if let Some(tlsport_cfg) = cfo.proxy_address_tls.clone() { - if let Some((_, port_str)) = tlsport_cfg.split_once(':') { - if let Ok(port) = port_str.parse::() { - cfo.proxy_port_tls = Some(port); - } - } - }; - cfo.proxy_tls_grade = parce_tls_grades(cfo.proxy_tls_grade.clone()); - cfo -} - -fn parce_tls_grades(what: Option) -> Option { - match what { - Some(g) => match g.to_ascii_lowercase().as_str() { - "high" => { - // info!("TLS grade set to: [ HIGH ]"); - Some("high".to_string()) - } - "medium" => { - // info!("TLS grade set to: [ MEDIUM ]"); - Some("medium".to_string()) - } - "unsafe" => { - // info!("TLS grade set to: [ UNSAFE ]"); - Some("unsafe".to_string()) - } - _ => { - warn!("Error parsing TLS grade, defaulting to: `medium`"); - Some("medium".to_string()) - } - }, - None => { - warn!("TLS grade not set, defaulting to: medium"); - Some("b".to_string()) - } - } -} - -fn log_builder(log_level: Option<&str>) { - // Use provided log level, or fall back to RUST_LOG env var, or default to "info" - let log_level = log_level - .map(|s| s.to_string()) - .or_else(|| std::env::var("RUST_LOG").ok()) - .unwrap_or_else(|| "info".to_string()); - unsafe { - match log_level.as_str() { - "info" => env::set_var("RUST_LOG", "info"), - "error" => env::set_var("RUST_LOG", "error"), - "warn" => env::set_var("RUST_LOG", "warn"), - "debug" => env::set_var("RUST_LOG", "debug"), - "trace" => env::set_var("RUST_LOG", "trace"), - "off" => env::set_var("RUST_LOG", "off"), - _ => { - println!("Error reading log level, defaulting to: INFO"); - env::set_var("RUST_LOG", "info") - } - } - } - // Use try_init() to avoid panic if logger is already initialized (e.g., from main.rs) - let _ = env_logger::builder().try_init(); -} - -pub fn build_headers(path_config: &Option>, config: &Configuration, hl: &mut Vec<(String, String)>) { - if let Some(headers) = &path_config { - for header in headers { - if let Some((key, val)) = header.split_once(':') { - hl.push((key.trim().to_string(), val.trim().to_string())); - } - } - if let Some(push) = config.headers.get("GLOBAL_HEADERS") { - for k in push.iter() { - for x in k.value() { - hl.push(x.to_owned()); - } - } - } - } -} +use crate::utils::healthcheck; +use crate::utils::state::{is_first_run, mark_not_first_run}; +use crate::utils::structs::*; +use crate::utils::tools::{clone_dashmap, clone_dashmap_into, print_upstreams}; +use dashmap::DashMap; +use log::{error, info, warn}; +use std::sync::atomic::AtomicUsize; +// use 
std::sync::mpsc::{channel, Receiver, Sender}; +use std::{env, fs}; +// use tokio::sync::oneshot::{Receiver, Sender}; + +pub async fn load_configuration(d: &str, kind: &str) -> Option { + let yaml_data = match kind { + "filepath" => match fs::read_to_string(d) { + Ok(data) => { + info!("Reading upstreams from {}", d); + data + } + Err(e) => { + error!("Reading: {}: {:?}", d, e); + warn!("Running with empty upstreams list, update it via API"); + return None; + } + }, + "content" => { + info!("Reading upstreams from API post body"); + d.to_string() + } + _ => { + error!("Mismatched parameter, only filepath|content is allowed"); + return None; + } + }; + + let parsed: Config = match serde_yaml::from_str(&yaml_data) { + Ok(cfg) => cfg, + Err(e) => { + error!("Failed to parse upstreams file: {}", e); + return None; + } + }; + + let mut toreturn = Configuration::default(); + + populate_headers_and_auth(&mut toreturn, &parsed).await; + toreturn.typecfg = parsed.provider.clone(); + + match parsed.provider.as_str() { + "file" => { + populate_file_upstreams(&mut toreturn, &parsed).await; + Some(toreturn) + } + "consul" => { + toreturn.consul = parsed.consul; + toreturn.consul.is_some().then_some(toreturn) + } + "kubernetes" => { + toreturn.kubernetes = parsed.kubernetes; + toreturn.kubernetes.is_some().then_some(toreturn) + } + _ => { + warn!("Unknown provider {}", parsed.provider); + None + } + } +} + +async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) { + // Handle new config format with nested config: section + if let Some(global_config) = &parsed.config { + // Use values from config: section if present + config.extraparams.sticky_sessions = global_config.sticky_sessions; + config.extraparams.https_proxy_enabled = Some(global_config.https_proxy_enabled); + config.extraparams.rate_limit = global_config.global_rate_limit; + + if let Some(headers) = &global_config.global_headers { + let mut hl = Vec::new(); + for header in headers { + if let Some((key, val)) = header.split_once(':') { + hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + + let global_headers = DashMap::new(); + global_headers.insert("/".to_string(), hl); + config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); + } + + if let Some(rate) = &global_config.global_rate_limit { + info!("Applied Global Rate Limit : {} request per second", rate); + } + + // Store healthcheck settings from upstreams config + config.healthcheck_interval = global_config.healthcheck_interval; + config.healthcheck_method = global_config.healthcheck_method.clone(); + } else { + // Fallback to old format (top-level fields) + if let Some(headers) = &parsed.headers { + let mut hl = Vec::new(); + for header in headers { + if let Some((key, val)) = header.split_once(':') { + hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + + let global_headers = DashMap::new(); + global_headers.insert("/".to_string(), hl); + config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); + } + + config.extraparams.sticky_sessions = parsed.sticky_sessions; + config.extraparams.https_proxy_enabled = None; // Legacy format doesn't have this + config.extraparams.rate_limit = parsed.rate_limit; + + if let Some(rate) = &parsed.rate_limit { + info!("Applied Global Rate Limit : {} request per second", rate); + } + } + + if let Some(auth) = &parsed.authorization { + let name = auth.get("type").unwrap_or(&"".to_string()).to_string(); + let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string(); + 
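+        // Illustrative YAML shape for this block (values assumed):
+        //   authorization:
+        //     type: Basic
+        //     creds: dXNlcjpwYXNz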
config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]); + } else { + config.extraparams.authentication = DashMap::new(); + } +} + +async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { + // Handle arxignis_paths first - these are global paths that work across all hostnames + if let Some(arxignis_paths) = &parsed.arxignis_paths { + info!("Processing {} Arxignis paths", arxignis_paths.len()); + for (path, path_config) in arxignis_paths { + let mut server_list = Vec::new(); + for server in &path_config.servers { + if let Some((ip, port_str)) = server.split_once(':') { + if let Ok(port) = port_str.parse::() { + let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); + let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); + let http2_enabled = path_config.http2_enabled.unwrap_or(false); + let disable_access_log = path_config.disable_access_log.unwrap_or(false); + server_list.push(InnerMap { + address: ip.trim().to_string(), + port, + ssl_enabled, + http2_enabled, + https_proxy_enabled, + rate_limit: path_config.rate_limit, + healthcheck: path_config.healthcheck, + disable_access_log, + }); + } + } + } + config.arxignis_paths.insert(path.clone(), (server_list, AtomicUsize::new(0))); + info!("Arxignis path {} -> {} backend(s)", path, config.arxignis_paths.get(path).unwrap().0.len()); + } + } + + let imtdashmap = UpstreamsDashMap::new(); + if let Some(upstreams) = &parsed.upstreams { + for (hostname, host_config) in upstreams { + // Store certificate mapping if specified + if let Some(certificate_name) = &host_config.certificate { + config.certificates.insert(hostname.clone(), certificate_name.clone()); + info!("Upstream {} will use certificate: {}", hostname, certificate_name); + } + + let path_map = DashMap::new(); + let header_list = DashMap::new(); + for (path, path_config) in &host_config.paths { + if let Some(rate) = &path_config.rate_limit { + info!("Applied Rate Limit for {} : {} request per second", hostname, rate); + } + + let mut hl: Vec<(String, String)> = Vec::new(); + build_headers(&path_config.headers, config, &mut hl); + header_list.insert(path.clone(), hl); + + let mut server_list = Vec::new(); + for server in &path_config.servers { + if let Some((ip, port_str)) = server.split_once(':') { + if let Ok(port) = port_str.parse::() { + let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); + let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); // Default to SSL + let http2_enabled = path_config.http2_enabled.unwrap_or(false); // Default to HTTP/1.1 + let disable_access_log = path_config.disable_access_log.unwrap_or(false); + server_list.push(InnerMap { + address: ip.trim().to_string(), + port, + ssl_enabled, + http2_enabled, + https_proxy_enabled, + rate_limit: path_config.rate_limit, + healthcheck: path_config.healthcheck, + disable_access_log, + }); + } + } + } + path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); + } + config.headers.insert(hostname.clone(), header_list); + imtdashmap.insert(hostname.clone(), path_map); + } + + if is_first_run() { + clone_dashmap_into(&imtdashmap, &config.upstreams); + mark_not_first_run(); + } else { + let y = clone_dashmap(&imtdashmap); + let r = healthcheck::initiate_upstreams(y).await; + clone_dashmap_into(&r, &config.upstreams); + } + info!("Upstream Config:"); + print_upstreams(&config.upstreams); + } +} +pub fn parce_main_config(path: &str) -> AppConfig { + parce_main_config_with_log_level(path, None) +} + +pub fn 
parce_main_config_with_log_level(path: &str, log_level: Option<&str>) -> AppConfig {
+    let data = fs::read_to_string(path).unwrap();
+
+    if let Ok(new_config) = serde_yaml::from_str::(&data) {
+        log_builder(log_level);
+        return new_config.pingora.to_app_config();
+    }
+
+    let mut cfo: AppConfig = serde_yaml::from_str(&*data).expect("Failed to parse main config file");
+    log_builder(log_level);
+    cfo.healthcheck_method = cfo.healthcheck_method.to_uppercase();
+    if let Some((ip, port_str)) = cfo.config_address.split_once(':') {
+        if let Ok(port) = port_str.parse::<u16>() {
+            cfo.local_server = Option::from((ip.to_string(), port));
+        }
+    }
+    if let Some(tlsport_cfg) = cfo.proxy_address_tls.clone() {
+        if let Some((_, port_str)) = tlsport_cfg.split_once(':') {
+            if let Ok(port) = port_str.parse::<u16>() {
+                cfo.proxy_port_tls = Some(port);
+            }
+        }
+    };
+    cfo.proxy_tls_grade = parce_tls_grades(cfo.proxy_tls_grade.clone());
+    cfo
+}
+
+fn parce_tls_grades(what: Option<String>) -> Option<String> {
+    match what {
+        Some(g) => match g.to_ascii_lowercase().as_str() {
+            "high" => {
+                // info!("TLS grade set to: [ HIGH ]");
+                Some("high".to_string())
+            }
+            "medium" => {
+                // info!("TLS grade set to: [ MEDIUM ]");
+                Some("medium".to_string())
+            }
+            "unsafe" => {
+                // info!("TLS grade set to: [ UNSAFE ]");
+                Some("unsafe".to_string())
+            }
+            _ => {
+                warn!("Error parsing TLS grade, defaulting to: `medium`");
+                Some("medium".to_string())
+            }
+        },
+        None => {
+            warn!("TLS grade not set, defaulting to: medium");
+            Some("medium".to_string())
+        }
+    }
+}
+
+fn log_builder(log_level: Option<&str>) {
+    // Use provided log level, or fall back to RUST_LOG env var, or default to "info"
+    let log_level = log_level
+        .map(|s| s.to_string())
+        .or_else(|| std::env::var("RUST_LOG").ok())
+        .unwrap_or_else(|| "info".to_string());
+    unsafe {
+        match log_level.as_str() {
+            "info" => env::set_var("RUST_LOG", "info"),
+            "error" => env::set_var("RUST_LOG", "error"),
+            "warn" => env::set_var("RUST_LOG", "warn"),
+            "debug" => env::set_var("RUST_LOG", "debug"),
+            "trace" => env::set_var("RUST_LOG", "trace"),
+            "off" => env::set_var("RUST_LOG", "off"),
+            _ => {
+                println!("Error reading log level, defaulting to: INFO");
+                env::set_var("RUST_LOG", "info")
+            }
+        }
+    }
+    // Use try_init() to avoid panic if logger is already initialized (e.g., from main.rs)
+    let _ = env_logger::builder().try_init();
+}
+
+pub fn build_headers(path_config: &Option<Vec<String>>, config: &Configuration, hl: &mut Vec<(String, String)>) {
+    if let Some(headers) = &path_config {
+        for header in headers {
+            if let Some((key, val)) = header.split_once(':') {
+                hl.push((key.trim().to_string(), val.trim().to_string()));
+            }
+        }
+        if let Some(push) = config.headers.get("GLOBAL_HEADERS") {
+            for k in push.iter() {
+                for x in k.value() {
+                    hl.push(x.to_owned());
+                }
+            }
+        }
+    }
+}
diff --git a/src/utils/state.rs b/src/utils/state.rs
index 075bfa3..0cdfb6b 100644
--- a/src/utils/state.rs
+++ b/src/utils/state.rs
@@ -1,19 +1,19 @@
-use once_cell::sync::Lazy;
-use std::sync::RwLock;
-
-#[derive(Debug)]
-pub struct SharedState {
-    pub first_run: bool,
-}
-
-pub static GLOBAL_STATE:
Lazy> = Lazy::new(|| RwLock::new(SharedState { first_run: true })); + +pub fn mark_not_first_run() { + let mut state = GLOBAL_STATE.write().unwrap(); + state.first_run = false; +} + +pub fn is_first_run() -> bool { + let state = GLOBAL_STATE.read().unwrap(); + state.first_run +} diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 3a2b80f..736a544 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -1,187 +1,187 @@ -use dashmap::DashMap; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; - -pub type UpstreamsDashMap = DashMap, AtomicUsize)>>; - -pub type UpstreamsIdMap = DashMap; -pub type Headers = DashMap>>; - -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct ServiceMapping { - pub upstream: String, - pub hostname: String, - pub path: Option, - #[serde(default)] - pub force_https: Option, - pub rate_limit: Option, - pub headers: Option>, -} - -// pub type Services = DashMap)>>; - -#[derive(Clone, Debug, Default)] -pub struct Extraparams { - pub sticky_sessions: bool, - pub https_proxy_enabled: Option, - pub authentication: DashMap>, - pub rate_limit: Option, -} -#[derive(Clone, Default, Debug, Serialize, Deserialize)] -pub struct Kubernetes { - pub servers: Option>, - pub services: Option>, - pub tokenpath: Option, -} - -#[derive(Clone, Default, Debug, Serialize, Deserialize)] -pub struct Consul { - pub servers: Option>, - pub services: Option>, - pub token: Option, -} -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct GlobalConfig { - #[serde(default)] - pub https_proxy_enabled: bool, - #[serde(default)] - pub sticky_sessions: bool, - #[serde(default)] - pub global_rate_limit: Option, - #[serde(default)] - pub global_headers: Option>, - #[serde(default)] - pub healthcheck_interval: Option, - #[serde(default)] - pub healthcheck_method: Option, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct Config { - #[serde(default = "default_provider")] - pub provider: String, - #[serde(default)] - pub config: Option, - #[serde(default)] - pub sticky_sessions: bool, - #[serde(default)] - pub arxignis_paths: Option>, - #[serde(default)] - pub upstreams: Option>, - #[serde(default)] - pub globals: Option>>, - #[serde(default)] - pub headers: Option>, - #[serde(default)] - pub authorization: Option>, - #[serde(default)] - pub consul: Option, - #[serde(default)] - pub kubernetes: Option, - #[serde(default)] - pub rate_limit: Option, -} - -fn default_provider() -> String { - "file".to_string() -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct HostConfig { - pub paths: HashMap, - pub rate_limit: Option, - #[serde(default)] - pub certificate: Option, - #[cfg(feature = "proxy")] - #[serde(default)] - pub acme: Option, -} - -impl HostConfig { - /// Check if a domain needs a certificate to be automatically requested via ACME - /// Only returns true if there's an explicit ACME configuration block. - /// If ssl_enabled is true but no ACME config exists, the user is expected to provide certificates manually. 
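// e.g. a host entry that opts in (acme payload shape not shown here):
//   upstreams:
//     example.com:
//       certificate: my-cert
//       acme: { ... }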
- pub fn needs_certificate(&self) -> bool { - // Only request certificates if ACME is explicitly configured - #[cfg(feature = "proxy")] - { - self.acme.is_some() - } - #[cfg(not(feature = "proxy"))] - { - false - } - } -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct PathConfig { - pub servers: Vec, - #[serde(default, alias = "force_https")] - pub https_proxy_enabled: Option, - #[serde(default)] - pub ssl_enabled: Option, - #[serde(default)] - pub http2_enabled: Option, - #[serde(default)] - pub headers: Option>, - #[serde(default)] - pub rate_limit: Option, - #[serde(default)] - pub healthcheck: Option, - #[serde(default)] - pub disable_access_log: Option, -} -#[derive(Debug, Default)] -pub struct Configuration { - pub arxignis_paths: DashMap, AtomicUsize)>, - pub upstreams: UpstreamsDashMap, - pub headers: Headers, - pub consul: Option, - pub kubernetes: Option, - pub typecfg: String, - pub extraparams: Extraparams, - pub certificates: DashMap, // hostname -> certificate_name mapping - pub healthcheck_interval: Option, - pub healthcheck_method: Option, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -#[serde(default)] -pub struct AppConfig { - pub healthcheck_interval: u16, - pub healthcheck_method: String, - pub master_key: String, - pub upstreams_conf: String, - pub config_address: String, - pub proxy_address_http: String, - pub config_api_enabled: bool, - pub config_tls_address: Option, - pub config_tls_certificate: Option, - pub config_tls_key_file: Option, - pub proxy_address_tls: Option, - pub proxy_port_tls: Option, - pub local_server: Option<(String, u16)>, - pub proxy_certificates: Option, - pub proxy_tls_grade: Option, - pub default_certificate: Option, - pub file_server_address: Option, - pub file_server_folder: Option, - pub runuser: Option, - pub rungroup: Option, - pub proxy_protocol_enabled: bool, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct InnerMap { - pub address: String, - pub port: u16, - pub ssl_enabled: bool, - pub http2_enabled: bool, - pub https_proxy_enabled: bool, - pub rate_limit: Option, - pub healthcheck: Option, - pub disable_access_log: bool, -} - +use dashmap::DashMap; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; + +pub type UpstreamsDashMap = DashMap, AtomicUsize)>>; + +pub type UpstreamsIdMap = DashMap; +pub type Headers = DashMap>>; + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct ServiceMapping { + pub upstream: String, + pub hostname: String, + pub path: Option, + #[serde(default)] + pub force_https: Option, + pub rate_limit: Option, + pub headers: Option>, +} + +// pub type Services = DashMap)>>; + +#[derive(Clone, Debug, Default)] +pub struct Extraparams { + pub sticky_sessions: bool, + pub https_proxy_enabled: Option, + pub authentication: DashMap>, + pub rate_limit: Option, +} +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct Kubernetes { + pub servers: Option>, + pub services: Option>, + pub tokenpath: Option, +} + +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct Consul { + pub servers: Option>, + pub services: Option>, + pub token: Option, +} +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct GlobalConfig { + #[serde(default)] + pub https_proxy_enabled: bool, + #[serde(default)] + pub sticky_sessions: bool, + #[serde(default)] + pub global_rate_limit: Option, + #[serde(default)] + pub global_headers: Option>, + #[serde(default)] + pub healthcheck_interval: Option, + 
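+    // Parsed from the nested `config:` section, e.g. (illustrative values):
+    //   config:
+    //     sticky_sessions: true
+    //     global_rate_limit: 100
+    //     healthcheck_interval: 5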
#[serde(default)] + pub healthcheck_method: Option, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct Config { + #[serde(default = "default_provider")] + pub provider: String, + #[serde(default)] + pub config: Option, + #[serde(default)] + pub sticky_sessions: bool, + #[serde(default)] + pub arxignis_paths: Option>, + #[serde(default)] + pub upstreams: Option>, + #[serde(default)] + pub globals: Option>>, + #[serde(default)] + pub headers: Option>, + #[serde(default)] + pub authorization: Option>, + #[serde(default)] + pub consul: Option, + #[serde(default)] + pub kubernetes: Option, + #[serde(default)] + pub rate_limit: Option, +} + +fn default_provider() -> String { + "file".to_string() +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct HostConfig { + pub paths: HashMap, + pub rate_limit: Option, + #[serde(default)] + pub certificate: Option, + #[cfg(feature = "proxy")] + #[serde(default)] + pub acme: Option, +} + +impl HostConfig { + /// Check if a domain needs a certificate to be automatically requested via ACME + /// Only returns true if there's an explicit ACME configuration block. + /// If ssl_enabled is true but no ACME config exists, the user is expected to provide certificates manually. + pub fn needs_certificate(&self) -> bool { + // Only request certificates if ACME is explicitly configured + #[cfg(feature = "proxy")] + { + self.acme.is_some() + } + #[cfg(not(feature = "proxy"))] + { + false + } + } +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct PathConfig { + pub servers: Vec, + #[serde(default, alias = "force_https")] + pub https_proxy_enabled: Option, + #[serde(default)] + pub ssl_enabled: Option, + #[serde(default)] + pub http2_enabled: Option, + #[serde(default)] + pub headers: Option>, + #[serde(default)] + pub rate_limit: Option, + #[serde(default)] + pub healthcheck: Option, + #[serde(default)] + pub disable_access_log: Option, +} +#[derive(Debug, Default)] +pub struct Configuration { + pub arxignis_paths: DashMap, AtomicUsize)>, + pub upstreams: UpstreamsDashMap, + pub headers: Headers, + pub consul: Option, + pub kubernetes: Option, + pub typecfg: String, + pub extraparams: Extraparams, + pub certificates: DashMap, // hostname -> certificate_name mapping + pub healthcheck_interval: Option, + pub healthcheck_method: Option, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(default)] +pub struct AppConfig { + pub healthcheck_interval: u16, + pub healthcheck_method: String, + pub master_key: String, + pub upstreams_conf: String, + pub config_address: String, + pub proxy_address_http: String, + pub config_api_enabled: bool, + pub config_tls_address: Option, + pub config_tls_certificate: Option, + pub config_tls_key_file: Option, + pub proxy_address_tls: Option, + pub proxy_port_tls: Option, + pub local_server: Option<(String, u16)>, + pub proxy_certificates: Option, + pub proxy_tls_grade: Option, + pub default_certificate: Option, + pub file_server_address: Option, + pub file_server_folder: Option, + pub runuser: Option, + pub rungroup: Option, + pub proxy_protocol_enabled: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct InnerMap { + pub address: String, + pub port: u16, + pub ssl_enabled: bool, + pub http2_enabled: bool, + pub https_proxy_enabled: bool, + pub rate_limit: Option, + pub healthcheck: Option, + pub disable_access_log: bool, +} + diff --git a/src/utils/tcp_fingerprint.rs b/src/utils/tcp_fingerprint.rs index 6d34ab1..9742484 100644 --- a/src/utils/tcp_fingerprint.rs +++ 
b/src/utils/tcp_fingerprint.rs @@ -1,1118 +1,1118 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; -use chrono::{DateTime, Utc}; -use std::net::Ipv4Addr; -use libbpf_rs::MapCore; -use crate::worker::log::{send_event, UnifiedEvent}; - -use crate::bpf::FilterSkel; - -/// TCP fingerprinting configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintConfig { - pub enabled: bool, - pub log_interval_secs: u64, - pub enable_fingerprint_events: bool, - pub fingerprint_events_interval_secs: u64, - pub min_packet_count: u32, - pub min_connection_duration_secs: u64, -} - -impl Default for TcpFingerprintConfig { - fn default() -> Self { - Self { - enabled: true, - log_interval_secs: 60, - enable_fingerprint_events: true, - fingerprint_events_interval_secs: 30, - min_packet_count: 3, - min_connection_duration_secs: 1, - } - } -} - -impl TcpFingerprintConfig { - /// Convert from CLI configuration - pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self { - Self { - enabled: cli_config.enabled, - log_interval_secs: cli_config.log_interval_secs, - enable_fingerprint_events: cli_config.enable_fingerprint_events, - fingerprint_events_interval_secs: cli_config.fingerprint_events_interval_secs, - min_packet_count: cli_config.min_packet_count, - min_connection_duration_secs: cli_config.min_connection_duration_secs, - } - } -} - -/// TCP fingerprint data collected from BPF -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintData { - pub first_seen: DateTime, - pub last_seen: DateTime, - pub packet_count: u32, - pub ttl: u16, - pub mss: u16, - pub window_size: u16, - pub window_scale: u8, - pub options_len: u8, - pub options: Vec, -} - -/// TCP fingerprint key -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintKey { - pub src_ip: String, - pub src_port: u16, - pub fingerprint: String, -} - -/// TCP fingerprint entry -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintEntry { - pub key: TcpFingerprintKey, - pub data: TcpFingerprintData, -} - -/// TCP SYN statistics -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpSynStats { - pub total_syns: u64, - pub unique_fingerprints: u64, - pub last_reset: DateTime, -} - -/// TCP fingerprinting statistics -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintStats { - pub timestamp: DateTime, - pub syn_stats: TcpSynStats, - pub fingerprints: Vec, - pub total_unique_fingerprints: u64, -} - -/// TCP fingerprint event for API -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintEvent { - pub event_type: String, - pub timestamp: DateTime, - pub src_ip: String, - pub src_port: u16, - pub fingerprint: String, - pub ttl: u16, - pub mss: u16, - pub window_size: u16, - pub window_scale: u8, - pub packet_count: u32 -} - -/// Collection of TCP fingerprint events -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintEvents { - pub events: Vec, - pub total_events: u64, - pub unique_ips: u64, -} - -/// Unique fingerprint pattern statistics -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UniqueFingerprintPattern { - pub pattern: String, - pub packet_count: u32, - pub unique_ips: usize, - pub entries: usize, -} - -/// Unique fingerprint statistics -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UniqueFingerprintStats { - pub timestamp: DateTime, - pub total_unique_patterns: usize, - pub total_unique_ips: usize, - pub total_packets: u32, - pub patterns: 
Vec, -} - -impl TcpFingerprintEvents { - /// Get top fingerprints by packet count - pub fn get_top_fingerprints(&self, limit: usize) -> Vec { - let mut events = self.events.clone(); - events.sort_by(|a, b| b.packet_count.cmp(&a.packet_count)); - events.into_iter().take(limit).collect() - } - - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - /// Generate summary string - pub fn summary(&self) -> String { - format!("TCP Fingerprint Events: {} events from {} unique IPs", - self.total_events, self.unique_ips) - } -} - -impl TcpFingerprintEvent { - /// Generate summary string - pub fn summary(&self) -> String { - format!("TCP Fingerprint: {}:{} {} (TTL:{}, MSS:{}, Window:{}, Scale:{}, Packets:{})", - self.src_ip, self.src_port, self.fingerprint, - self.ttl, self.mss, self.window_size, self.window_scale, self.packet_count) - } -} - -impl UniqueFingerprintStats { - /// Generate summary string - pub fn summary(&self) -> String { - format!("Unique Fingerprint Stats: {} patterns, {} unique IPs, {} total packets", - self.total_unique_patterns, self.total_unique_ips, self.total_packets) - } - - /// Convert to JSON string - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } -} - -impl TcpFingerprintStats { - /// Generate summary string - pub fn summary(&self) -> String { - let mut summary = format!("TCP Fingerprint Stats: {} SYN packets, {} unique fingerprints, {} total entries", - self.syn_stats.total_syns, self.syn_stats.unique_fingerprints, self.total_unique_fingerprints); - - // Add top unique fingerprints if any - if !self.fingerprints.is_empty() { - summary.push_str(&format!(", {} unique fingerprints found", self.fingerprints.len())); - - // Show top 5 fingerprints by packet count - let mut fingerprint_vec: Vec<_> = self.fingerprints.iter().collect(); - fingerprint_vec.sort_by(|a, b| b.data.packet_count.cmp(&a.data.packet_count)); - - if !fingerprint_vec.is_empty() { - summary.push_str(", Top fingerprints: "); - for (i, entry) in fingerprint_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } - summary.push_str(&format!("{}:{}:{}:{}", - entry.key.src_ip, entry.key.src_port, entry.key.fingerprint, entry.data.packet_count)); - } - } - } - - summary - } -} - -/// Global TCP fingerprint collector -static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock> = std::sync::OnceLock::new(); - -/// Set the global TCP fingerprint collector -pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) { - let _ = TCP_FINGERPRINT_COLLECTOR.set(Arc::new(collector)); -} - -/// Get the global TCP fingerprint collector -pub fn get_global_tcp_fingerprint_collector() -> Option> { - TCP_FINGERPRINT_COLLECTOR.get().cloned() -} - -/// TCP fingerprint collector -#[derive(Clone)] -pub struct TcpFingerprintCollector { - skels: Vec>>, - enabled: bool, - config: TcpFingerprintConfig, -} - -impl TcpFingerprintCollector { - /// Create a new TCP fingerprint collector - pub fn new(skels: Vec>>, enabled: bool) -> Self { - Self { - skels, - enabled, - config: TcpFingerprintConfig::default(), - } - } - - /// Create a new TCP fingerprint collector with configuration - pub fn new_with_config(skels: Vec>>, config: TcpFingerprintConfig) -> Self { - Self { - skels, - enabled: config.enabled, - config, - } - } - - /// Enable or disable fingerprint collection - pub fn set_enabled(&mut self, enabled: bool) { - self.enabled = enabled; - } - - /// Check if fingerprint collection is enabled - pub fn is_enabled(&self) -> bool { - 
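// Reflects the `enabled` flag from TcpFingerprintConfig; lookup_fingerprint()
// below returns None early when this is false or no skeletons are loaded.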
self.enabled - } - - /// Lookup TCP fingerprint for a specific source IP and port - pub fn lookup_fingerprint(&self, src_ip: std::net::IpAddr, src_port: u16) -> Option { - if !self.enabled || self.skels.is_empty() { - return None; - } - - match src_ip { - std::net::IpAddr::V4(ip) => { - let octets = ip.octets(); - let src_ip_be = u32::from_be_bytes(octets); - - // Try to find fingerprint in any skeleton's IPv4 map - for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - for (key_bytes, value_bytes) in iter { - if key_bytes.len() >= 6 && value_bytes.len() >= 32 { - // Parse key structure: src_ip (4 bytes BE), src_port (2 bytes BE), fingerprint (14 bytes) - // BPF stores IP as __be32 (big-endian), so read as big-endian - let key_ip = u32::from_be_bytes([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]); - let key_port = u16::from_be_bytes([key_bytes[4], key_bytes[5]]); - - if key_ip == src_ip_be && key_port == src_port { - // Parse value structure - if value_bytes.len() >= 32 { - let first_seen = u64::from_ne_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7] - ]); - let last_seen = u64::from_ne_bytes([ - value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11], - value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15] - ]); - let packet_count = u32::from_ne_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19] - ]); - let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); - let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); - let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); - let window_scale = value_bytes[26]; - let options_len = value_bytes[27]; - - let options_size = options_len.min(16) as usize; - let mut options = vec![0u8; options_size]; - if value_bytes.len() >= 28 + options_size { - options.copy_from_slice(&value_bytes[28..28 + options_size]); - } - - return Some(TcpFingerprintData { - first_seen: DateTime::from_timestamp_nanos(first_seen as i64), - last_seen: DateTime::from_timestamp_nanos(last_seen as i64), - packet_count, - ttl, - mss, - window_size, - window_scale, - options_len, - options, - }); - } - } - } - } - } - } - None - } - std::net::IpAddr::V6(ip) => { - let octets = ip.octets(); - - // Try to find fingerprint in any skeleton's IPv6 map - for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { - for (key_bytes, value_bytes) in iter { - if key_bytes.len() >= 18 && value_bytes.len() >= 32 { - // Parse key structure: src_ip (16 bytes), src_port (2 bytes BE), fingerprint (14 bytes) - let mut key_ip = [0u8; 16]; - key_ip.copy_from_slice(&key_bytes[0..16]); - let key_port = u16::from_be_bytes([key_bytes[16], key_bytes[17]]); - - if key_ip == octets && key_port == src_port { - // Parse value structure (same as IPv4) - if value_bytes.len() >= 32 { - let first_seen = u64::from_ne_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7] - ]); - let last_seen = u64::from_ne_bytes([ - value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11], - value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15] - ]); - let packet_count = u32::from_ne_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19] - 
]); - let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); - let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); - let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); - let window_scale = value_bytes[26]; - let options_len = value_bytes[27]; - - let options_size = options_len.min(16) as usize; - let mut options = vec![0u8; options_size]; - if value_bytes.len() >= 28 + options_size { - options.copy_from_slice(&value_bytes[28..28 + options_size]); - } - - return Some(TcpFingerprintData { - first_seen: DateTime::from_timestamp_nanos(first_seen as i64), - last_seen: DateTime::from_timestamp_nanos(last_seen as i64), - packet_count, - ttl, - mss, - window_size, - window_scale, - options_len, - options, - }); - } - } - } - } - } - } - None - } - } - } - - /// Collect TCP fingerprint statistics from all BPF skeletons - pub fn collect_fingerprint_stats(&self) -> Result, Box> { - if !self.enabled { - return Ok(vec![]); - } - - let mut stats = Vec::new(); - for (i, skel) in self.skels.iter().enumerate() { - log::debug!("Collecting TCP fingerprint stats from skeleton {}", i); - match self.collect_fingerprint_stats_from_skeleton(skel) { - Ok(stat) => { - log::debug!("Skeleton {} collected {} fingerprints", i, stat.fingerprints.len()); - stats.push(stat); - } - Err(e) => { - log::warn!("Failed to collect TCP fingerprint stats from skeleton {}: {}", i, e); - } - } - } - log::debug!("Collected stats from {} skeletons", stats.len()); - Ok(stats) - } - - /// Collect TCP fingerprint statistics from a single BPF skeleton - fn collect_fingerprint_stats_from_skeleton(&self, skel: &FilterSkel) -> Result> { - if !self.enabled { - return Ok(TcpFingerprintStats { - timestamp: Utc::now(), - syn_stats: TcpSynStats { - total_syns: 0, - unique_fingerprints: 0, - last_reset: Utc::now(), - }, - fingerprints: Vec::new(), - total_unique_fingerprints: 0, - }); - } - - let mut fingerprints = Vec::new(); - - // Read TCP SYN statistics - log::debug!("Reading TCP SYN statistics from skeleton"); - let syn_stats = self.collect_syn_stats(skel)?; - log::debug!("TCP SYN stats: {} total_syns, {} unique_fingerprints", syn_stats.total_syns, syn_stats.unique_fingerprints); - - // Read TCP fingerprints from BPF map - log::debug!("Reading TCP fingerprints from skeleton"); - self.collect_tcp_fingerprints(skel, &mut fingerprints)?; - log::debug!("Collected {} fingerprints from skeleton", fingerprints.len()); - - let total_unique_fingerprints = fingerprints.len() as u64; - - Ok(TcpFingerprintStats { - timestamp: Utc::now(), - syn_stats, - fingerprints, - total_unique_fingerprints, - }) - } - - /// Collect aggregated TCP fingerprint statistics across all skeletons - pub fn collect_aggregated_stats(&self) -> Result> { - if !self.enabled { - return Err("TCP fingerprint collection is disabled".into()); - } - - log::debug!("Collecting aggregated TCP fingerprint statistics from {} skeletons", self.skels.len()); - let individual_stats = self.collect_fingerprint_stats()?; - log::debug!("Collected {} individual stats", individual_stats.len()); - - if individual_stats.is_empty() { - log::warn!("No TCP fingerprint statistics available from any skeleton"); - return Err("No TCP fingerprint statistics available".into()); - } - - // Aggregate statistics across all skeletons - let mut aggregated = TcpFingerprintStats { - timestamp: Utc::now(), - syn_stats: TcpSynStats { - total_syns: 0, - unique_fingerprints: 0, - last_reset: Utc::now(), - }, - fingerprints: Vec::new(), - total_unique_fingerprints: 0, - }; - 
-        let mut all_fingerprints: std::collections::HashMap<String, TcpFingerprintEntry> = std::collections::HashMap::new();
-
-        for stat in individual_stats {
-            aggregated.syn_stats.total_syns += stat.syn_stats.total_syns;
-            aggregated.syn_stats.unique_fingerprints += stat.syn_stats.unique_fingerprints;
-
-            // Merge fingerprints by key (src_ip:src_port:fingerprint)
-            for entry in stat.fingerprints {
-                let key = format!("{}:{}:{}", entry.key.src_ip, entry.key.src_port, entry.key.fingerprint);
-                match all_fingerprints.get_mut(&key) {
-                    Some(existing) => {
-                        // Update packet count and timestamps
-                        existing.data.packet_count += entry.data.packet_count;
-                        if entry.data.first_seen < existing.data.first_seen {
-                            existing.data.first_seen = entry.data.first_seen;
-                        }
-                        if entry.data.last_seen > existing.data.last_seen {
-                            existing.data.last_seen = entry.data.last_seen;
-                        }
-                    }
-                    None => {
-                        all_fingerprints.insert(key, entry);
-                    }
-                }
-            }
-        }
-
-        aggregated.fingerprints = all_fingerprints.into_values().collect();
-        aggregated.total_unique_fingerprints = aggregated.fingerprints.len() as u64;
-
-        Ok(aggregated)
-    }
-
-    /// Collect TCP SYN statistics
-    fn collect_syn_stats(&self, skel: &FilterSkel) -> Result<TcpSynStats, Box<dyn std::error::Error>> {
-        let key = 0u32.to_le_bytes();
-        let stats_bytes = skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY)
-            .map_err(|e| format!("Failed to read TCP SYN stats: {}", e))?;
-
-        if let Some(bytes) = stats_bytes {
-            if bytes.len() >= 24 { // 3 * u64 = 24 bytes
-                let total_syns = u64::from_le_bytes([
-                    bytes[0], bytes[1], bytes[2], bytes[3],
-                    bytes[4], bytes[5], bytes[6], bytes[7],
-                ]);
-                let unique_fingerprints = u64::from_le_bytes([
-                    bytes[8], bytes[9], bytes[10], bytes[11],
-                    bytes[12], bytes[13], bytes[14], bytes[15],
-                ]);
-                let _last_reset = u64::from_le_bytes([
-                    bytes[16], bytes[17], bytes[18], bytes[19],
-                    bytes[20], bytes[21], bytes[22], bytes[23],
-                ]);
-
-                Ok(TcpSynStats {
-                    total_syns,
-                    unique_fingerprints,
-                    last_reset: Utc::now(), // Use current time as fallback
-                })
-            } else {
-                Ok(TcpSynStats {
-                    total_syns: 0,
-                    unique_fingerprints: 0,
-                    last_reset: Utc::now(),
-                })
-            }
-        } else {
-            Ok(TcpSynStats {
-                total_syns: 0,
-                unique_fingerprints: 0,
-                last_reset: Utc::now(),
-            })
-        }
-    }
-
-    /// Collect TCP fingerprints from BPF map
-    fn collect_tcp_fingerprints(&self, skel: &FilterSkel, fingerprints: &mut Vec<TcpFingerprintEntry>) -> Result<(), Box<dyn std::error::Error>> {
-        log::debug!("Collecting TCP fingerprints from BPF map (IPv4)");
-
-        let mut count = 0;
-        let mut skipped_count = 0;
-        let mut batch_worked = false;
-
-        // Helper closure to process a single entry
-        let mut process_entry = |key_bytes: &[u8], value_bytes: &[u8]| {
-            log::debug!("Processing IPv4 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
-
-            if key_bytes.len() >= 20 && value_bytes.len() >= 48 {
-                let src_ip = Ipv4Addr::from([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]);
-                let src_port = u16::from_le_bytes([key_bytes[4], key_bytes[5]]);
-                let fingerprint = String::from_utf8_lossy(&key_bytes[6..20]).trim_end_matches('\0').to_string();
-
-                let packet_count = u32::from_le_bytes([
-                    value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19],
-                ]);
-                let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]);
-                let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]);
-                let window_size = u16::from_le_bytes([value_bytes[24], value_bytes[25]]);
-                let window_scale = value_bytes[26];
-                let options_len = value_bytes[27].min(16);
-
-                if packet_count > 0 && packet_count >= self.config.min_packet_count {
-                    let options = value_bytes[32..32 + options_len as usize].to_vec();
-
-                    let entry = TcpFingerprintEntry {
-                        key: TcpFingerprintKey {
-                            src_ip: src_ip.to_string(),
-                            src_port,
-                            fingerprint: fingerprint.clone(),
-                        },
-                        data: TcpFingerprintData {
-                            first_seen: Utc::now(),
-                            last_seen: Utc::now(),
-                            packet_count,
-                            ttl,
-                            mss,
-                            window_size,
-                            window_scale,
-                            options_len,
-                            options,
-                        },
-                    };
-
-                    log::debug!("TCP Fingerprint: {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}",
-                        src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint);
-
-                    fingerprints.push(entry);
-                    return (true, false); // (processed, skipped)
-                } else {
-                    log::debug!("Skipping fingerprint entry with packet_count={} (threshold={}): {}:{}",
-                        packet_count, self.config.min_packet_count, src_ip, src_port);
-                    return (false, true);
-                }
-            } else {
-                log::debug!("Skipping fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
-                return (false, true);
-            }
-        };
-
-        // Try batch lookup first
-        if let Ok(batch_iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
-            for (key_bytes, value_bytes) in batch_iter {
-                batch_worked = true;
-                let (processed, skipped) = process_entry(&key_bytes, &value_bytes);
-                if processed { count += 1; }
-                if skipped { skipped_count += 1; }
-            }
-        }
-
-        // If batch lookup returned nothing, try keys iterator as fallback
-        if !batch_worked {
-            log::debug!("Batch lookup empty, trying keys iterator for IPv4 TCP fingerprints");
-            for key_bytes in skel.maps.tcp_fingerprints.keys() {
-                if key_bytes.len() >= 20 {
-                    if let Ok(Some(value_bytes)) = skel.maps.tcp_fingerprints.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) {
-                        let (processed, skipped) = process_entry(&key_bytes, &value_bytes);
-                        if processed { count += 1; }
-                        if skipped { skipped_count += 1; }
-                    }
-                }
-            }
-        }
-
-        log::debug!("Found {} IPv4 TCP fingerprints, skipped {} entries", count, skipped_count);
-
-        // Collect IPv6 fingerprints
-        log::debug!("Collecting TCP fingerprints from BPF map (IPv6)");
-
-        match skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
-            Ok(batch_iter) => {
-                let mut count = 0;
-                let mut skipped_count = 0;
-                for (key_bytes, value_bytes) in batch_iter {
-                    log::debug!("Processing IPv6 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
-
-                    if key_bytes.len() >= 32 && value_bytes.len() >= 48 { // Key: 16+2+14, Value: same as IPv4
-                        // Parse IPv6 address (16 bytes)
-                        let src_ip: std::net::Ipv6Addr = std::net::Ipv6Addr::from([
-                            key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3],
-                            key_bytes[4], key_bytes[5], key_bytes[6], key_bytes[7],
-                            key_bytes[8], key_bytes[9], key_bytes[10], key_bytes[11],
-                            key_bytes[12], key_bytes[13], key_bytes[14], key_bytes[15]
-                        ]);
-                        let src_port = u16::from_le_bytes([key_bytes[16], key_bytes[17]]);
-                        let fingerprint = String::from_utf8_lossy(&key_bytes[18..32]).trim_end_matches('\0').to_string();
-
-                        // Parse fingerprint data (same structure as IPv4)
-                        let _first_seen = u64::from_le_bytes([
-                            value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3],
-                            value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7],
-                        ]);
-                        let _last_seen = u64::from_le_bytes([
-                            value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11],
-                            value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15],
-                        ]);
-                        let packet_count = u32::from_le_bytes([
-                            value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19],
-                        ]);
-                        let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]);
-                        let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]);
-                        let window_size = u16::from_le_bytes([value_bytes[24], value_bytes[25]]);
-                        let window_scale = value_bytes[26];
-                        let options_len = value_bytes[27].min(16); // Cap at TCP_FP_MAX_OPTION_LEN
-
-                        // Only process entries with packet_count > 0 and above threshold
-                        if packet_count > 0 && packet_count >= self.config.min_packet_count {
-                            let options = value_bytes[32..32 + options_len as usize].to_vec();
-
-                            let entry = TcpFingerprintEntry {
-                                key: TcpFingerprintKey {
-                                    src_ip: src_ip.to_string(),
-                                    src_port,
-                                    fingerprint: fingerprint.clone(),
-                                },
-                                data: TcpFingerprintData {
-                                    first_seen: Utc::now(), // Use current time as fallback
-                                    last_seen: Utc::now(), // Use current time as fallback
-                                    packet_count,
-                                    ttl,
-                                    mss,
-                                    window_size,
-                                    window_scale,
-                                    options_len,
-                                    options,
-                                },
-                            };
-
-                            // Log new IPv6 TCP fingerprint at debug level
-                            log::debug!("TCP Fingerprint (IPv6): {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}",
-                                src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint);
-
-                            fingerprints.push(entry);
-                            count += 1;
-                        } else {
-                            log::debug!("Skipping IPv6 fingerprint entry with packet_count={} (threshold={}): {}:{}",
-                                packet_count, self.config.min_packet_count, src_ip, src_port);
-                            skipped_count += 1;
-                        }
-                    } else {
-                        log::debug!("Skipping IPv6 fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
-                        skipped_count += 1;
-                    }
-                }
-                log::debug!("Found {} IPv6 TCP fingerprints, skipped {} entries", count, skipped_count);
-            }
-            Err(e) => {
-                log::warn!("Failed to read IPv6 TCP fingerprints: {}", e);
-            }
-        }
-
-        Ok(())
-    }
-
-
-    /// Log current TCP fingerprint statistics
-    pub fn log_stats(&self) -> Result<(), Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(());
-        }
-
-        match self.collect_aggregated_stats() {
-            Ok(stats) => {
-                log::debug!("{}", stats.summary());
-
-                // Log detailed unique fingerprint information
-                if stats.total_unique_fingerprints > 0 {
-                    // Group fingerprints by fingerprint string to show unique patterns
-                    let mut fingerprint_groups: std::collections::HashMap<String, Vec<&TcpFingerprintEntry>> = std::collections::HashMap::new();
-                    for entry in &stats.fingerprints {
-                        fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry);
-                    }
-
-                    log::debug!("Unique fingerprint patterns: {} different patterns found", fingerprint_groups.len());
-
-                    // Show top unique fingerprint patterns by total packet count
-                    let mut pattern_stats: Vec<_> = fingerprint_groups.iter().map(|(pattern, entries)| {
-                        let total_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum();
-                        let unique_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect();
-                        (pattern, total_packets, unique_ips.len(), entries.len())
-                    }).collect();
-
-                    pattern_stats.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by packet count
-
-                    log::debug!("Top unique fingerprint patterns:");
-                    for (i, (pattern, total_packets, unique_ips, entries)) in pattern_stats.iter().take(10).enumerate() {
-                        log::debug!("  {}: {} ({} packets, {} unique IPs, {} entries)",
-                            i + 1, pattern, total_packets, unique_ips, entries);
-                    }
-
-                    // Log as JSON for structured logging
-                    if let Ok(json) = serde_json::to_string(&stats) {
-                        log::debug!("TCP Fingerprint Stats JSON: {}", json);
-                    }
-                } else {
-                    log::debug!("No unique fingerprints found");
-                }
-
-                Ok(())
-            }
-            Err(e) => {
-                log::warn!("Failed to collect TCP fingerprint statistics: {}", e);
-                Err(e)
-            }
-        }
-    }
-
-    /// Get unique fingerprint statistics
-    pub fn get_unique_fingerprint_stats(&self) -> Result<UniqueFingerprintStats, Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(UniqueFingerprintStats {
-                timestamp: Utc::now(),
-                total_unique_patterns: 0,
-                total_unique_ips: 0,
-                total_packets: 0,
-                patterns: Vec::new(),
-            });
-        }
-
-        let stats = self.collect_aggregated_stats()?;
-
-        // Group fingerprints by pattern
-        let mut fingerprint_groups: std::collections::HashMap<String, Vec<&TcpFingerprintEntry>> = std::collections::HashMap::new();
-        for entry in &stats.fingerprints {
-            fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry);
-        }
-
-        let mut patterns = Vec::new();
-        let mut total_unique_ips: std::collections::HashSet<&String> = std::collections::HashSet::new();
-        let mut total_packets = 0u32;
-
-        for (pattern, entries) in fingerprint_groups {
-            let pattern_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum();
-            let pattern_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect();
-
-            total_unique_ips.extend(pattern_ips.iter());
-            total_packets += pattern_packets;
-
-            patterns.push(UniqueFingerprintPattern {
-                pattern: pattern.clone(),
-                packet_count: pattern_packets,
-                unique_ips: pattern_ips.len(),
-                entries: entries.len(),
-            });
-        }
-
-        patterns.sort_by(|a, b| b.packet_count.cmp(&a.packet_count));
-
-        Ok(UniqueFingerprintStats {
-            timestamp: Utc::now(),
-            total_unique_patterns: patterns.len(),
-            total_unique_ips: total_unique_ips.len(),
-            total_packets,
-            patterns,
-        })
-    }
-
-    /// Log unique fingerprint statistics
-    pub fn log_unique_fingerprint_stats(&self) -> Result<(), Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(());
-        }
-
-        match self.get_unique_fingerprint_stats() {
-            Ok(stats) => {
-                log::debug!("{}", stats.summary());
-
-                if stats.total_unique_patterns > 0 {
-                    log::debug!("Top unique fingerprint patterns:");
-                    for (i, pattern) in stats.patterns.iter().take(10).enumerate() {
-                        log::debug!("  {}: {} ({} packets, {} unique IPs, {} entries)",
-                            i + 1, pattern.pattern, pattern.packet_count, pattern.unique_ips, pattern.entries);
-                    }
-
-                    // Log as JSON for structured logging
-                    if let Ok(json) = stats.to_json() {
-                        log::debug!("Unique Fingerprint Stats JSON: {}", json);
-                    }
-                } else {
-                    log::debug!("No unique fingerprint patterns found");
-                }
-
-                Ok(())
-            }
-            Err(e) => {
-                log::warn!("Failed to collect unique fingerprint statistics: {}", e);
-                Err(e)
-            }
-        }
-    }
-
-    /// Collect TCP fingerprint events from all BPF skeletons
-    pub fn collect_fingerprint_events(&self) -> Result<TcpFingerprintEvents, Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(TcpFingerprintEvents {
-                events: Vec::new(),
-                total_events: 0,
-                unique_ips: 0,
-            });
-        }
-
-        let mut all_events = Vec::new();
-        let mut unique_ips = std::collections::HashSet::new();
-
-        for skel in &self.skels {
-            let mut fingerprints = Vec::new();
-            self.collect_tcp_fingerprints(skel, &mut fingerprints)?;
-
-            // Convert to events
-            for entry in fingerprints {
-                let event = TcpFingerprintEvent {
-                    event_type: "tcp_fingerprint".to_string(),
-                    timestamp: Utc::now(),
-                    src_ip: entry.key.src_ip.clone(),
-                    src_port: entry.key.src_port,
-                    fingerprint: entry.key.fingerprint.clone(),
-                    ttl: entry.data.ttl,
-                    mss: entry.data.mss,
-                    window_size: entry.data.window_size,
-                    window_scale: entry.data.window_scale,
-                    packet_count: entry.data.packet_count
-                };
-
-                unique_ips.insert(event.src_ip.clone());
-                all_events.push(event);
-            }
-        }
-
-        let total_events = all_events.len() as u64;
-        let unique_ips_count = unique_ips.len() as u64;
-
-        Ok(TcpFingerprintEvents {
-            events: all_events,
-            total_events,
-            unique_ips: unique_ips_count,
-        })
-    }
-
-    /// Log TCP fingerprint events
-    pub fn log_fingerprint_events(&self) -> Result<(), Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(());
-        }
-
-        let events = self.collect_fingerprint_events()?;
-
-        if events.total_events > 0 {
-            log::debug!("{}", events.summary());
-
-            // Log top 5 fingerprints
-            let top_fingerprints = events.get_top_fingerprints(5);
-            for event in top_fingerprints {
-                log::debug!("  {}", event.summary());
-            }
-
-            // Log as JSON for structured logging
-            if let Ok(json) = events.to_json() {
-                log::debug!("TCP Fingerprint Events JSON: {}", json);
-            }
-
-            // Send events to unified queue
-            for event in events.events {
-                send_event(UnifiedEvent::TcpFingerprint(event));
-            }
-
-            // Reset the counters after logging
-            self.reset_fingerprint_counters()?;
-        } else {
-            log::debug!("No TCP fingerprint events found");
-        }
-
-        Ok(())
-    }
-
-
-    /// Reset TCP fingerprint counters in BPF maps
-    pub fn reset_fingerprint_counters(&self) -> Result<(), Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(());
-        }
-
-        log::debug!("Resetting TCP fingerprint counters");
-
-        for skel in &self.skels {
-            // Reset TCP fingerprints map
-            match skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
-                Ok(batch_iter) => {
-                    let mut reset_count = 0;
-                    for (key_bytes, _) in batch_iter {
-                        if key_bytes.len() >= 20 {
-                            // Create zero value for fingerprint data (48 bytes with padding)
-                            let zero_value = vec![0u8; 48];
-                            if let Err(e) = skel.maps.tcp_fingerprints.update(&key_bytes, &zero_value, libbpf_rs::MapFlags::ANY) {
-                                log::warn!("Failed to reset TCP fingerprint counter: {}", e);
-                            } else {
-                                reset_count += 1;
-                            }
-                        }
-                    }
-                    log::debug!("Reset {} TCP fingerprint counters", reset_count);
-                }
-                Err(e) => {
-                    log::warn!("Failed to reset TCP fingerprint counters: {}", e);
-                }
-            }
-
-            // Reset TCP SYN stats
-            let key = 0u32.to_le_bytes();
-            let zero_stats = vec![0u8; 24]; // 3 * u64 = 24 bytes
-            if let Err(e) = skel.maps.tcp_syn_stats.update(&key, &zero_stats, libbpf_rs::MapFlags::ANY) {
-                log::warn!("Failed to reset TCP SYN stats: {}", e);
-            } else {
-                log::debug!("Reset TCP SYN stats");
-            }
-        }
-
-        Ok(())
-    }
-
-    /// Check if BPF maps are accessible
-    pub fn check_maps_accessible(&self) -> Result<(), Box<dyn std::error::Error>> {
-        if !self.enabled {
-            return Ok(());
-        }
-
-        for (i, skel) in self.skels.iter().enumerate() {
-            log::debug!("Checking accessibility of BPF maps for skeleton {}", i);
-
-            // Check tcp_fingerprints map
-            match skel.maps.tcp_fingerprints.lookup_batch(1, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
-                Ok(_) => log::debug!("tcp_fingerprints map is accessible for skeleton {}", i),
-                Err(e) => log::warn!("tcp_fingerprints map not accessible for skeleton {}: {}", i, e),
-            }
-
-            // Check tcp_syn_stats map
-            let key = 0u32.to_le_bytes();
-            match skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY) {
-                Ok(_) => log::debug!("tcp_syn_stats map is accessible for skeleton {}", i),
-                Err(e) => log::warn!("tcp_syn_stats map not accessible for skeleton {}: {}", i, e),
-            }
-        }
-
-        Ok(())
-    }
-}
-
-/// Configuration for TCP fingerprint collection
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TcpFingerprintCollectorConfig {
-    pub enabled: bool,
-    pub log_interval_secs: u64,
-    pub fingerprint_events_interval_secs: u64,
-}
-
-impl Default for TcpFingerprintCollectorConfig {
-    fn default() -> Self {
-        Self {
-            enabled: true,
-            log_interval_secs: 60, // Log stats every minute
-            fingerprint_events_interval_secs: 30, // Send events every 30 seconds
-        }
-    }
-}
-
-impl TcpFingerprintCollectorConfig {
-    /// Create a new configuration
-    pub fn new(enabled: bool, log_interval_secs: u64, fingerprint_events_interval_secs: u64) -> Self {
-        Self {
-            enabled,
-            log_interval_secs,
-            fingerprint_events_interval_secs,
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_tcp_fingerprint_config_default() {
-        let config = TcpFingerprintConfig::default();
-        assert!(config.enabled);
-        assert_eq!(config.log_interval_secs, 60);
-        assert!(config.enable_fingerprint_events);
-        assert_eq!(config.fingerprint_events_interval_secs, 30);
-    }
-
-    #[test]
-    fn test_tcp_fingerprint_collector_config_default() {
-        let config = TcpFingerprintCollectorConfig::default();
-        assert!(config.enabled);
-        assert_eq!(config.log_interval_secs, 60);
-        assert_eq!(config.fingerprint_events_interval_secs, 30);
-    }
-
-    #[test]
-    fn test_tcp_fingerprint_event_summary() {
-        let event = TcpFingerprintEvent {
-            event_type: "tcp_fingerprint".to_string(),
-            timestamp: Utc::now(),
-            src_ip: "192.168.1.1".to_string(),
-            src_port: 80,
-            fingerprint: "64:1460:65535:7".to_string(),
-            ttl: 64,
-            mss: 1460,
-            window_size: 65535,
-            window_scale: 7,
-            packet_count: 1
-        };
-
-        let summary = event.summary();
-        assert!(summary.contains("192.168.1.1"));
-        assert!(summary.contains("80"));
-        assert!(summary.contains("64:1460:65535:7"));
-    }
-
-    #[test]
-    fn test_unique_fingerprint_stats() {
-        let stats = UniqueFingerprintStats {
-            timestamp: Utc::now(),
-            total_unique_patterns: 2,
-            total_unique_ips: 3,
-            total_packets: 100,
-            patterns: vec![
-                UniqueFingerprintPattern {
-                    pattern: "64:1460:65535:7".to_string(),
-                    packet_count: 60,
-                    unique_ips: 2,
-                    entries: 2,
-                },
-                UniqueFingerprintPattern {
-                    pattern: "128:1460:32768:8".to_string(),
-                    packet_count: 40,
-                    unique_ips: 1,
-                    entries: 1,
-                },
-            ],
-        };
-
-        let summary = stats.summary();
-        assert!(summary.contains("2 patterns"));
-        assert!(summary.contains("3 unique IPs"));
-        assert!(summary.contains("100 total packets"));
-
-        let json = stats.to_json().unwrap();
-        assert!(json.contains("total_unique_patterns"));
-        assert!(json.contains("patterns"));
-    }
-}
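[Editorial sketch, not part of the patch: the value record that both halves of this hunk decode by hand is a packed ~48-byte BPF map entry. A minimal standalone decoder for that layout follows; the struct name RawFingerprintValue and the options_at parameter are illustrative, not taken from the BPF headers. Note that lookup_fingerprint reads options at offset 28 while collect_tcp_fingerprints reads them at offset 32, so the sketch leaves the offset to the caller rather than resolving the discrepancy.]

    /// Assumed layout of the fingerprint value bytes, per the offsets above.
    #[derive(Debug)]
    struct RawFingerprintValue {
        first_seen_ns: u64, // bytes 0..8, kernel timestamp in nanoseconds
        last_seen_ns: u64,  // bytes 8..16
        packet_count: u32,  // bytes 16..20
        ttl: u16,           // bytes 20..22
        mss: u16,           // bytes 22..24
        window_size: u16,   // bytes 24..26
        window_scale: u8,   // byte 26
        options_len: u8,    // byte 27, capped at 16 in the code above
        options: Vec<u8>,   // starts at `options_at` (28 or 32 above)
    }

    fn decode_value(value: &[u8], options_at: usize) -> Option<RawFingerprintValue> {
        if value.len() < 28 {
            return None;
        }
        let options_len = value[27].min(16);
        let end = options_at + options_len as usize;
        Some(RawFingerprintValue {
            first_seen_ns: u64::from_ne_bytes(value[0..8].try_into().ok()?),
            last_seen_ns: u64::from_ne_bytes(value[8..16].try_into().ok()?),
            packet_count: u32::from_ne_bytes(value[16..20].try_into().ok()?),
            ttl: u16::from_ne_bytes(value[20..22].try_into().ok()?),
            mss: u16::from_ne_bytes(value[22..24].try_into().ok()?),
            window_size: u16::from_ne_bytes(value[24..26].try_into().ok()?),
            window_scale: value[26],
            options_len,
            // Tolerate short buffers instead of panicking on a bad slice.
            options: value.get(options_at..end).map(|s| s.to_vec()).unwrap_or_default(),
        })
    }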
+use std::sync::Arc;
+use serde::{Deserialize, Serialize};
+use chrono::{DateTime, Utc};
+use std::net::Ipv4Addr;
+use libbpf_rs::MapCore;
+use crate::worker::log::{send_event, UnifiedEvent};
+
+use crate::bpf::FilterSkel;
+
+/// TCP fingerprinting configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintConfig {
+    pub enabled: bool,
+    pub log_interval_secs: u64,
+    pub enable_fingerprint_events: bool,
+    pub fingerprint_events_interval_secs: u64,
+    pub min_packet_count: u32,
+    pub min_connection_duration_secs: u64,
+}
+
+impl Default for TcpFingerprintConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            log_interval_secs: 60,
+            enable_fingerprint_events: true,
+            fingerprint_events_interval_secs: 30,
+            min_packet_count: 3,
+            min_connection_duration_secs: 1,
+        }
+    }
+}
+
+impl TcpFingerprintConfig {
+    /// Convert from CLI configuration
+    pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self {
+        Self {
+            enabled: cli_config.enabled,
+            log_interval_secs: cli_config.log_interval_secs,
+            enable_fingerprint_events: cli_config.enable_fingerprint_events,
+            fingerprint_events_interval_secs: cli_config.fingerprint_events_interval_secs,
+            min_packet_count: cli_config.min_packet_count,
+            min_connection_duration_secs: cli_config.min_connection_duration_secs,
+        }
+    }
+}
+
+/// TCP fingerprint data collected from BPF
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintData {
+    pub first_seen: DateTime<Utc>,
+    pub last_seen: DateTime<Utc>,
+    pub packet_count: u32,
+    pub ttl: u16,
+    pub mss: u16,
+    pub window_size: u16,
+    pub window_scale: u8,
+    pub options_len: u8,
+    pub options: Vec<u8>,
+}
+
+/// TCP fingerprint key
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintKey {
+    pub src_ip: String,
+    pub src_port: u16,
+    pub fingerprint: String,
+}
+
+/// TCP fingerprint entry
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintEntry {
+    pub key: TcpFingerprintKey,
+    pub data: TcpFingerprintData,
+}
+
+/// TCP SYN statistics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpSynStats {
+    pub total_syns: u64,
+    pub unique_fingerprints: u64,
+    pub last_reset: DateTime<Utc>,
+}
+
+/// TCP fingerprinting statistics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintStats {
+    pub timestamp: DateTime<Utc>,
+    pub syn_stats: TcpSynStats,
+    pub fingerprints: Vec<TcpFingerprintEntry>,
+    pub total_unique_fingerprints: u64,
+}
+
+/// TCP fingerprint event for API
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintEvent {
+    pub event_type: String,
+    pub timestamp: DateTime<Utc>,
+    pub src_ip: String,
+    pub src_port: u16,
+    pub fingerprint: String,
+    pub ttl: u16,
+    pub mss: u16,
+    pub window_size: u16,
+    pub window_scale: u8,
+    pub packet_count: u32
+}
+
+/// Collection of TCP fingerprint events
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintEvents {
+    pub events: Vec<TcpFingerprintEvent>,
+    pub total_events: u64,
+    pub unique_ips: u64,
+}
+
+/// Unique fingerprint pattern statistics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct UniqueFingerprintPattern {
+    pub pattern: String,
+    pub packet_count: u32,
+    pub unique_ips: usize,
+    pub entries: usize,
+}
+
+/// Unique fingerprint statistics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct UniqueFingerprintStats {
+    pub timestamp: DateTime<Utc>,
+    pub total_unique_patterns: usize,
+    pub total_unique_ips: usize,
+    pub total_packets: u32,
+    pub patterns: Vec<UniqueFingerprintPattern>,
+}
+
+impl TcpFingerprintEvents {
+    /// Get top fingerprints by packet count
+    pub fn get_top_fingerprints(&self, limit: usize) -> Vec<TcpFingerprintEvent> {
+        let mut events = self.events.clone();
+        events.sort_by(|a, b| b.packet_count.cmp(&a.packet_count));
+        events.into_iter().take(limit).collect()
+    }
+
+    /// Convert to JSON string
+    pub fn to_json(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+
+    /// Generate summary string
+    pub fn summary(&self) -> String {
+        format!("TCP Fingerprint Events: {} events from {} unique IPs",
+            self.total_events, self.unique_ips)
+    }
+}
+
+impl TcpFingerprintEvent {
+    /// Generate summary string
+    pub fn summary(&self) -> String {
+        format!("TCP Fingerprint: {}:{} {} (TTL:{}, MSS:{}, Window:{}, Scale:{}, Packets:{})",
+            self.src_ip, self.src_port, self.fingerprint,
+            self.ttl, self.mss, self.window_size, self.window_scale, self.packet_count)
+    }
+}
+
+impl UniqueFingerprintStats {
+    /// Generate summary string
+    pub fn summary(&self) -> String {
+        format!("Unique Fingerprint Stats: {} patterns, {} unique IPs, {} total packets",
+            self.total_unique_patterns, self.total_unique_ips, self.total_packets)
+    }
+
+    /// Convert to JSON string
+    pub fn to_json(&self) -> Result<String, serde_json::Error> {
+        serde_json::to_string(self)
+    }
+}
+
+impl TcpFingerprintStats {
+    /// Generate summary string
+    pub fn summary(&self) -> String {
+        let mut summary = format!("TCP Fingerprint Stats: {} SYN packets, {} unique fingerprints, {} total entries",
+            self.syn_stats.total_syns, self.syn_stats.unique_fingerprints, self.total_unique_fingerprints);
+
+        // Add top unique fingerprints if any
+        if !self.fingerprints.is_empty() {
+            summary.push_str(&format!(", {} unique fingerprints found", self.fingerprints.len()));
+
+            // Show top 5 fingerprints by packet count
+            let mut fingerprint_vec: Vec<_> = self.fingerprints.iter().collect();
+            fingerprint_vec.sort_by(|a, b| b.data.packet_count.cmp(&a.data.packet_count));
+
+            if !fingerprint_vec.is_empty() {
+                summary.push_str(", Top fingerprints: ");
+                for (i, entry) in fingerprint_vec.iter().take(5).enumerate() {
+                    if i > 0 { summary.push_str(", "); }
+                    summary.push_str(&format!("{}:{}:{}:{}",
+                        entry.key.src_ip, entry.key.src_port, entry.key.fingerprint, entry.data.packet_count));
+                }
+            }
+        }
+
+        summary
+    }
+}
+
+/// Global TCP fingerprint collector
+static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock<Arc<TcpFingerprintCollector>> = std::sync::OnceLock::new();
+
+/// Set the global TCP fingerprint collector
+pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) {
+    let _ = TCP_FINGERPRINT_COLLECTOR.set(Arc::new(collector));
+}
+
+/// Get the global TCP fingerprint collector
+pub fn get_global_tcp_fingerprint_collector() -> Option<Arc<TcpFingerprintCollector>> {
+    TCP_FINGERPRINT_COLLECTOR.get().cloned()
+}
+
+/// TCP fingerprint collector
+#[derive(Clone)]
+pub struct TcpFingerprintCollector {
+    skels: Vec<Arc<FilterSkel<'static>>>,
+    enabled: bool,
+    config: TcpFingerprintConfig,
+}
+
+impl TcpFingerprintCollector {
+    /// Create a new TCP fingerprint collector
+    pub fn new(skels: Vec<Arc<FilterSkel<'static>>>, enabled: bool) -> Self {
+        Self {
+            skels,
+            enabled,
+            config: TcpFingerprintConfig::default(),
+        }
+    }
+
+    /// Create a new TCP fingerprint collector with configuration
+    pub fn new_with_config(skels: Vec<Arc<FilterSkel<'static>>>, config: TcpFingerprintConfig) -> Self {
+        Self {
+            skels,
+            enabled: config.enabled,
+            config,
+        }
+    }
+
+    /// Enable or disable fingerprint collection
+    pub fn set_enabled(&mut self, enabled: bool) {
+        self.enabled = enabled;
+    }
+
+    /// Check if fingerprint collection is enabled
+    pub fn is_enabled(&self) -> bool {
+        self.enabled
+    }
+
+    /// Lookup TCP fingerprint for a specific source IP and port
+    pub fn lookup_fingerprint(&self, src_ip: std::net::IpAddr, src_port: u16) -> Option<TcpFingerprintData> {
+        if !self.enabled || self.skels.is_empty() {
+            return None;
+        }
+
+        match src_ip {
+            std::net::IpAddr::V4(ip) => {
+                let octets = ip.octets();
+                let src_ip_be = u32::from_be_bytes(octets);
+
+                // Try to find fingerprint in any skeleton's IPv4 map
+                for skel in &self.skels {
+                    if let Ok(iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+                        for (key_bytes, value_bytes) in iter {
+                            if key_bytes.len() >= 6 && value_bytes.len() >= 32 {
+                                // Parse key structure: src_ip (4 bytes BE), src_port (2 bytes BE), fingerprint (14 bytes)
+                                // BPF stores IP as __be32 (big-endian), so read as big-endian
+                                let key_ip = u32::from_be_bytes([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]);
+                                let key_port = u16::from_be_bytes([key_bytes[4], key_bytes[5]]);
+
+                                if key_ip == src_ip_be && key_port == src_port {
+                                    // Parse value structure
+                                    if value_bytes.len() >= 32 {
+                                        let first_seen = u64::from_ne_bytes([
+                                            value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3],
+                                            value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7]
+                                        ]);
+                                        let last_seen = u64::from_ne_bytes([
+                                            value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11],
+                                            value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15]
+                                        ]);
+                                        let packet_count = u32::from_ne_bytes([
+                                            value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19]
+                                        ]);
+                                        let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]);
+                                        let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]);
+                                        let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]);
+                                        let window_scale = value_bytes[26];
+                                        let options_len = value_bytes[27];
+
+                                        let options_size = options_len.min(16) as usize;
+                                        let mut options = vec![0u8; options_size];
+                                        if value_bytes.len() >= 28 + options_size {
+                                            options.copy_from_slice(&value_bytes[28..28 + options_size]);
+                                        }
+
+                                        return Some(TcpFingerprintData {
+                                            first_seen: DateTime::from_timestamp_nanos(first_seen as i64),
+                                            last_seen: DateTime::from_timestamp_nanos(last_seen as i64),
+                                            packet_count,
+                                            ttl,
+                                            mss,
+                                            window_size,
+                                            window_scale,
+                                            options_len,
+                                            options,
+                                        });
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                None
+            }
+            std::net::IpAddr::V6(ip) => {
+                let octets = ip.octets();
+
+                // Try to find fingerprint in any skeleton's IPv6 map
+                for skel in &self.skels {
+                    if let Ok(iter) = skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+                        for (key_bytes, value_bytes) in iter {
+                            if key_bytes.len() >= 18 && value_bytes.len() >= 32 {
+                                // Parse key structure: src_ip (16 bytes), src_port (2 bytes BE), fingerprint (14 bytes)
+                                let mut key_ip = [0u8; 16];
+                                key_ip.copy_from_slice(&key_bytes[0..16]);
+                                let key_port = u16::from_be_bytes([key_bytes[16], key_bytes[17]]);
+
+                                if key_ip == octets && key_port == src_port {
+                                    // Parse value structure (same as IPv4)
+                                    if value_bytes.len() >= 32 {
+                                        let first_seen = u64::from_ne_bytes([
+                                            value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3],
+                                            value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7]
+                                        ]);
+                                        let last_seen = u64::from_ne_bytes([
+                                            value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11],
+                                            value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15]
+                                        ]);
+                                        let packet_count = u32::from_ne_bytes([
+                                            value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19]
+                                        ]);
+                                        let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]);
+                                        let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]);
+                                        let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]);
+                                        let window_scale = value_bytes[26];
+                                        let options_len = value_bytes[27];
+
+                                        let options_size = options_len.min(16) as usize;
+                                        let mut options = vec![0u8; options_size];
+                                        if value_bytes.len() >= 28 + options_size {
+                                            options.copy_from_slice(&value_bytes[28..28 + options_size]);
+                                        }
+
+                                        return Some(TcpFingerprintData {
+                                            first_seen: DateTime::from_timestamp_nanos(first_seen as i64),
+                                            last_seen: DateTime::from_timestamp_nanos(last_seen as i64),
+                                            packet_count,
+                                            ttl,
+                                            mss,
+                                            window_size,
+                                            window_scale,
+                                            options_len,
+                                            options,
+                                        });
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                None
+            }
+        }
+    }
+
+    /// Collect TCP fingerprint statistics from all BPF skeletons
+    pub fn collect_fingerprint_stats(&self) -> Result<Vec<TcpFingerprintStats>, Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(vec![]);
+        }
+
+        let mut stats = Vec::new();
+        for (i, skel) in self.skels.iter().enumerate() {
+            log::debug!("Collecting TCP fingerprint stats from skeleton {}", i);
+            match self.collect_fingerprint_stats_from_skeleton(skel) {
+                Ok(stat) => {
+                    log::debug!("Skeleton {} collected {} fingerprints", i, stat.fingerprints.len());
+                    stats.push(stat);
+                }
+                Err(e) => {
+                    log::warn!("Failed to collect TCP fingerprint stats from skeleton {}: {}", i, e);
+                }
+            }
+        }
+        log::debug!("Collected stats from {} skeletons", stats.len());
+        Ok(stats)
+    }
+
+    /// Collect TCP fingerprint statistics from a single BPF skeleton
+    fn collect_fingerprint_stats_from_skeleton(&self, skel: &FilterSkel) -> Result<TcpFingerprintStats, Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(TcpFingerprintStats {
+                timestamp: Utc::now(),
+                syn_stats: TcpSynStats {
+                    total_syns: 0,
+                    unique_fingerprints: 0,
+                    last_reset: Utc::now(),
+                },
+                fingerprints: Vec::new(),
+                total_unique_fingerprints: 0,
+            });
+        }
+
+        let mut fingerprints = Vec::new();
+
+        // Read TCP SYN statistics
+        log::debug!("Reading TCP SYN statistics from skeleton");
+        let syn_stats = self.collect_syn_stats(skel)?;
+        log::debug!("TCP SYN stats: {} total_syns, {} unique_fingerprints", syn_stats.total_syns, syn_stats.unique_fingerprints);
+
+        // Read TCP fingerprints from BPF map
+        log::debug!("Reading TCP fingerprints from skeleton");
+        self.collect_tcp_fingerprints(skel, &mut fingerprints)?;
+        log::debug!("Collected {} fingerprints from skeleton", fingerprints.len());
+
+        let total_unique_fingerprints = fingerprints.len() as u64;
+
+        Ok(TcpFingerprintStats {
+            timestamp: Utc::now(),
+            syn_stats,
+            fingerprints,
+            total_unique_fingerprints,
+        })
+    }
+
+    /// Collect aggregated TCP fingerprint statistics across all skeletons
+    pub fn collect_aggregated_stats(&self) -> Result<TcpFingerprintStats, Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Err("TCP fingerprint collection is disabled".into());
+        }
+
+        log::debug!("Collecting aggregated TCP fingerprint statistics from {} skeletons", self.skels.len());
+        let individual_stats = self.collect_fingerprint_stats()?;
+        log::debug!("Collected {} individual stats", individual_stats.len());
+
+        if individual_stats.is_empty() {
+            log::warn!("No TCP fingerprint statistics available from any skeleton");
+            return Err("No TCP fingerprint statistics available".into());
+        }
+
+        // Aggregate statistics across all skeletons
+        let mut aggregated = TcpFingerprintStats {
+            timestamp: Utc::now(),
+            syn_stats: TcpSynStats {
+                total_syns: 0,
+                unique_fingerprints: 0,
+                last_reset: Utc::now(),
+            },
+            fingerprints: Vec::new(),
+            total_unique_fingerprints: 0,
+        };
+
+        let mut all_fingerprints: std::collections::HashMap<String, TcpFingerprintEntry> = std::collections::HashMap::new();
+
+        for stat in individual_stats {
+            aggregated.syn_stats.total_syns += stat.syn_stats.total_syns;
+            aggregated.syn_stats.unique_fingerprints += stat.syn_stats.unique_fingerprints;
+
+            // Merge fingerprints by key (src_ip:src_port:fingerprint)
+            for entry in stat.fingerprints {
+                let key = format!("{}:{}:{}", entry.key.src_ip, entry.key.src_port, entry.key.fingerprint);
+                match all_fingerprints.get_mut(&key) {
+                    Some(existing) => {
+                        // Update packet count and timestamps
+                        existing.data.packet_count += entry.data.packet_count;
+                        if entry.data.first_seen < existing.data.first_seen {
+                            existing.data.first_seen = entry.data.first_seen;
+                        }
+                        if entry.data.last_seen > existing.data.last_seen {
+                            existing.data.last_seen = entry.data.last_seen;
+                        }
+                    }
+                    None => {
+                        all_fingerprints.insert(key, entry);
+                    }
+                }
+            }
+        }
+
+        aggregated.fingerprints = all_fingerprints.into_values().collect();
+        aggregated.total_unique_fingerprints = aggregated.fingerprints.len() as u64;
+
+        Ok(aggregated)
+    }
+
+    /// Collect TCP SYN statistics
+    fn collect_syn_stats(&self, skel: &FilterSkel) -> Result<TcpSynStats, Box<dyn std::error::Error>> {
+        let key = 0u32.to_le_bytes();
+        let stats_bytes = skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY)
+            .map_err(|e| format!("Failed to read TCP SYN stats: {}", e))?;
+
+        if let Some(bytes) = stats_bytes {
+            if bytes.len() >= 24 { // 3 * u64 = 24 bytes
+                let total_syns = u64::from_le_bytes([
+                    bytes[0], bytes[1], bytes[2], bytes[3],
+                    bytes[4], bytes[5], bytes[6], bytes[7],
+                ]);
+                let unique_fingerprints = u64::from_le_bytes([
+                    bytes[8], bytes[9], bytes[10], bytes[11],
+                    bytes[12], bytes[13], bytes[14], bytes[15],
+                ]);
+                let _last_reset = u64::from_le_bytes([
+                    bytes[16], bytes[17], bytes[18], bytes[19],
+                    bytes[20], bytes[21], bytes[22], bytes[23],
+                ]);
+
+                Ok(TcpSynStats {
+                    total_syns,
+                    unique_fingerprints,
+                    last_reset: Utc::now(), // Use current time as fallback
+                })
+            } else {
+                Ok(TcpSynStats {
+                    total_syns: 0,
+                    unique_fingerprints: 0,
+                    last_reset: Utc::now(),
+                })
+            }
+        } else {
+            Ok(TcpSynStats {
+                total_syns: 0,
+                unique_fingerprints: 0,
+                last_reset: Utc::now(),
+            })
+        }
+    }
+
+    /// Collect TCP fingerprints from BPF map
+    fn collect_tcp_fingerprints(&self, skel: &FilterSkel, fingerprints: &mut Vec<TcpFingerprintEntry>) -> Result<(), Box<dyn std::error::Error>> {
+        log::debug!("Collecting TCP fingerprints from BPF map (IPv4)");
+
+        let mut count = 0;
+        let mut skipped_count = 0;
+        let mut batch_worked = false;
+
+        // Helper closure to process a single entry
+        let mut process_entry = |key_bytes: &[u8], value_bytes: &[u8]| {
+            log::debug!("Processing IPv4 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
+
+            if key_bytes.len() >= 20 && value_bytes.len() >= 48 {
+                let src_ip = Ipv4Addr::from([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]);
+                let src_port = u16::from_le_bytes([key_bytes[4], key_bytes[5]]);
+                let fingerprint = String::from_utf8_lossy(&key_bytes[6..20]).trim_end_matches('\0').to_string();
+
+                let packet_count = u32::from_le_bytes([
+                    value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19],
+                ]);
+                let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]);
+                let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]);
+                let window_size = u16::from_le_bytes([value_bytes[24], value_bytes[25]]);
+                let window_scale = value_bytes[26];
+                let options_len = value_bytes[27].min(16);
+
+                if packet_count > 0 && packet_count >= self.config.min_packet_count {
+                    let options = value_bytes[32..32 + options_len as usize].to_vec();
+
+                    let entry = TcpFingerprintEntry {
+                        key: TcpFingerprintKey {
+                            src_ip: src_ip.to_string(),
+                            src_port,
+                            fingerprint: fingerprint.clone(),
+                        },
+                        data: TcpFingerprintData {
+                            first_seen: Utc::now(),
+                            last_seen: Utc::now(),
+                            packet_count,
+                            ttl,
+                            mss,
+                            window_size,
+                            window_scale,
+                            options_len,
+                            options,
+                        },
+                    };
+
+                    log::debug!("TCP Fingerprint: {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}",
+                        src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint);
+
+                    fingerprints.push(entry);
+                    return (true, false); // (processed, skipped)
+                } else {
+                    log::debug!("Skipping fingerprint entry with packet_count={} (threshold={}): {}:{}",
+                        packet_count, self.config.min_packet_count, src_ip, src_port);
+                    return (false, true);
+                }
+            } else {
+                log::debug!("Skipping fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
+                return (false, true);
+            }
+        };
+
+        // Try batch lookup first
+        if let Ok(batch_iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+            for (key_bytes, value_bytes) in batch_iter {
+                batch_worked = true;
+                let (processed, skipped) = process_entry(&key_bytes, &value_bytes);
+                if processed { count += 1; }
+                if skipped { skipped_count += 1; }
+            }
+        }
+
+        // If batch lookup returned nothing, try keys iterator as fallback
+        if !batch_worked {
+            log::debug!("Batch lookup empty, trying keys iterator for IPv4 TCP fingerprints");
+            for key_bytes in skel.maps.tcp_fingerprints.keys() {
+                if key_bytes.len() >= 20 {
+                    if let Ok(Some(value_bytes)) = skel.maps.tcp_fingerprints.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) {
+                        let (processed, skipped) = process_entry(&key_bytes, &value_bytes);
+                        if processed { count += 1; }
+                        if skipped { skipped_count += 1; }
+                    }
+                }
+            }
+        }
+
+        log::debug!("Found {} IPv4 TCP fingerprints, skipped {} entries", count, skipped_count);
+
+        // Collect IPv6 fingerprints
+        log::debug!("Collecting TCP fingerprints from BPF map (IPv6)");
+
+        match skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+            Ok(batch_iter) => {
+                let mut count = 0;
+                let mut skipped_count = 0;
+                for (key_bytes, value_bytes) in batch_iter {
+                    log::debug!("Processing IPv6 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
+
+                    if key_bytes.len() >= 32 && value_bytes.len() >= 48 { // Key: 16+2+14, Value: same as IPv4
+                        // Parse IPv6 address (16 bytes)
+                        let src_ip: std::net::Ipv6Addr = std::net::Ipv6Addr::from([
+                            key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3],
+                            key_bytes[4], key_bytes[5], key_bytes[6], key_bytes[7],
+                            key_bytes[8], key_bytes[9], key_bytes[10], key_bytes[11],
+                            key_bytes[12], key_bytes[13], key_bytes[14], key_bytes[15]
+                        ]);
+                        let src_port = u16::from_le_bytes([key_bytes[16], key_bytes[17]]);
+                        let fingerprint = String::from_utf8_lossy(&key_bytes[18..32]).trim_end_matches('\0').to_string();
+
+                        // Parse fingerprint data (same structure as IPv4)
+                        let _first_seen = u64::from_le_bytes([
+                            value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3],
+                            value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7],
+                        ]);
+                        let _last_seen = u64::from_le_bytes([
+                            value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11],
+                            value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15],
+                        ]);
+                        let packet_count = u32::from_le_bytes([
+                            value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19],
+                        ]);
+                        let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]);
+                        let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]);
+                        let window_size = u16::from_le_bytes([value_bytes[24], value_bytes[25]]);
+                        let window_scale = value_bytes[26];
+                        let options_len = value_bytes[27].min(16); // Cap at TCP_FP_MAX_OPTION_LEN
+
+                        // Only process entries with packet_count > 0 and above threshold
+                        if packet_count > 0 && packet_count >= self.config.min_packet_count {
+                            let options = value_bytes[32..32 + options_len as usize].to_vec();
+
+                            let entry = TcpFingerprintEntry {
+                                key: TcpFingerprintKey {
+                                    src_ip: src_ip.to_string(),
+                                    src_port,
+                                    fingerprint: fingerprint.clone(),
+                                },
+                                data: TcpFingerprintData {
+                                    first_seen: Utc::now(), // Use current time as fallback
+                                    last_seen: Utc::now(), // Use current time as fallback
+                                    packet_count,
+                                    ttl,
+                                    mss,
+                                    window_size,
+                                    window_scale,
+                                    options_len,
+                                    options,
+                                },
+                            };
+
+                            // Log new IPv6 TCP fingerprint at debug level
+                            log::debug!("TCP Fingerprint (IPv6): {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}",
+                                src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint);
+
+                            fingerprints.push(entry);
+                            count += 1;
+                        } else {
+                            log::debug!("Skipping IPv6 fingerprint entry with packet_count={} (threshold={}): {}:{}",
+                                packet_count, self.config.min_packet_count, src_ip, src_port);
+                            skipped_count += 1;
+                        }
+                    } else {
+                        log::debug!("Skipping IPv6 fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len());
+                        skipped_count += 1;
+                    }
+                }
+                log::debug!("Found {} IPv6 TCP fingerprints, skipped {} entries", count, skipped_count);
+            }
+            Err(e) => {
+                log::warn!("Failed to read IPv6 TCP fingerprints: {}", e);
+            }
+        }
+
+        Ok(())
+    }
+
+
+    /// Log current TCP fingerprint statistics
+    pub fn log_stats(&self) -> Result<(), Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(());
+        }
+
+        match self.collect_aggregated_stats() {
+            Ok(stats) => {
+                log::debug!("{}", stats.summary());
+
+                // Log detailed unique fingerprint information
+                if stats.total_unique_fingerprints > 0 {
+                    // Group fingerprints by fingerprint string to show unique patterns
+                    let mut fingerprint_groups: std::collections::HashMap<String, Vec<&TcpFingerprintEntry>> = std::collections::HashMap::new();
+                    for entry in &stats.fingerprints {
+                        fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry);
+                    }
+
+                    log::debug!("Unique fingerprint patterns: {} different patterns found", fingerprint_groups.len());
+
+                    // Show top unique fingerprint patterns by total packet count
+                    let mut pattern_stats: Vec<_> = fingerprint_groups.iter().map(|(pattern, entries)| {
+                        let total_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum();
+                        let unique_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect();
+                        (pattern, total_packets, unique_ips.len(), entries.len())
+                    }).collect();
+
+                    pattern_stats.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by packet count
+
+                    log::debug!("Top unique fingerprint patterns:");
+                    for (i, (pattern, total_packets, unique_ips, entries)) in pattern_stats.iter().take(10).enumerate() {
+                        log::debug!("  {}: {} ({} packets, {} unique IPs, {} entries)",
+                            i + 1, pattern, total_packets, unique_ips, entries);
+                    }
+
+                    // Log as JSON for structured logging
+                    if let Ok(json) = serde_json::to_string(&stats) {
+                        log::debug!("TCP Fingerprint Stats JSON: {}", json);
+                    }
+                } else {
+                    log::debug!("No unique fingerprints found");
+                }
+
+                Ok(())
+            }
+            Err(e) => {
+                log::warn!("Failed to collect TCP fingerprint statistics: {}", e);
+                Err(e)
+            }
+        }
+    }
+
+    /// Get unique fingerprint statistics
+    pub fn get_unique_fingerprint_stats(&self) -> Result<UniqueFingerprintStats, Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(UniqueFingerprintStats {
+                timestamp: Utc::now(),
+                total_unique_patterns: 0,
+                total_unique_ips: 0,
+                total_packets: 0,
+                patterns: Vec::new(),
+            });
+        }
+
+        let stats = self.collect_aggregated_stats()?;
+
+        // Group fingerprints by pattern
+        let mut fingerprint_groups: std::collections::HashMap<String, Vec<&TcpFingerprintEntry>> = std::collections::HashMap::new();
+        for entry in &stats.fingerprints {
+            fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry);
+        }
+
+        let mut patterns = Vec::new();
+        let mut total_unique_ips: std::collections::HashSet<&String> = std::collections::HashSet::new();
+        let mut total_packets = 0u32;
+
+        for (pattern, entries) in fingerprint_groups {
+            let pattern_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum();
+            let pattern_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect();
+
+            total_unique_ips.extend(pattern_ips.iter());
+            total_packets += pattern_packets;
+
+            patterns.push(UniqueFingerprintPattern {
+                pattern: pattern.clone(),
+                packet_count: pattern_packets,
+                unique_ips: pattern_ips.len(),
+                entries: entries.len(),
+            });
+        }
+
+        patterns.sort_by(|a, b| b.packet_count.cmp(&a.packet_count));
+
+        Ok(UniqueFingerprintStats {
+            timestamp: Utc::now(),
+            total_unique_patterns: patterns.len(),
+            total_unique_ips: total_unique_ips.len(),
+            total_packets,
+            patterns,
+        })
+    }
+
+    /// Log unique fingerprint statistics
+    pub fn log_unique_fingerprint_stats(&self) -> Result<(), Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(());
+        }
+
+        match self.get_unique_fingerprint_stats() {
+            Ok(stats) => {
+                log::debug!("{}", stats.summary());
+
+                if stats.total_unique_patterns > 0 {
+                    log::debug!("Top unique fingerprint patterns:");
+                    for (i, pattern) in stats.patterns.iter().take(10).enumerate() {
+                        log::debug!("  {}: {} ({} packets, {} unique IPs, {} entries)",
+                            i + 1, pattern.pattern, pattern.packet_count, pattern.unique_ips, pattern.entries);
+                    }
+
+                    // Log as JSON for structured logging
+                    if let Ok(json) = stats.to_json() {
+                        log::debug!("Unique Fingerprint Stats JSON: {}", json);
+                    }
+                } else {
+                    log::debug!("No unique fingerprint patterns found");
+                }
+
+                Ok(())
+            }
+            Err(e) => {
+                log::warn!("Failed to collect unique fingerprint statistics: {}", e);
+                Err(e)
+            }
+        }
+    }
+
+    /// Collect TCP fingerprint events from all BPF skeletons
+    pub fn collect_fingerprint_events(&self) -> Result<TcpFingerprintEvents, Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(TcpFingerprintEvents {
+                events: Vec::new(),
+                total_events: 0,
+                unique_ips: 0,
+            });
+        }
+
+        let mut all_events = Vec::new();
+        let mut unique_ips = std::collections::HashSet::new();
+
+        for skel in &self.skels {
+            let mut fingerprints = Vec::new();
+            self.collect_tcp_fingerprints(skel, &mut fingerprints)?;
+
+            // Convert to events
+            for entry in fingerprints {
+                let event = TcpFingerprintEvent {
+                    event_type: "tcp_fingerprint".to_string(),
+                    timestamp: Utc::now(),
+                    src_ip: entry.key.src_ip.clone(),
+                    src_port: entry.key.src_port,
+                    fingerprint: entry.key.fingerprint.clone(),
+                    ttl: entry.data.ttl,
+                    mss: entry.data.mss,
+                    window_size: entry.data.window_size,
+                    window_scale: entry.data.window_scale,
+                    packet_count: entry.data.packet_count
+                };
+
+                unique_ips.insert(event.src_ip.clone());
+                all_events.push(event);
+            }
+        }
+
+        let total_events = all_events.len() as u64;
+        let unique_ips_count = unique_ips.len() as u64;
+
+        Ok(TcpFingerprintEvents {
+            events: all_events,
+            total_events,
+            unique_ips: unique_ips_count,
+        })
+    }
+
+    /// Log TCP fingerprint events
+    pub fn log_fingerprint_events(&self) -> Result<(), Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(());
+        }
+
+        let events = self.collect_fingerprint_events()?;
+
+        if events.total_events > 0 {
+            log::debug!("{}", events.summary());
+
+            // Log top 5 fingerprints
+            let top_fingerprints = events.get_top_fingerprints(5);
+            for event in top_fingerprints {
+                log::debug!("  {}", event.summary());
+            }
+
+            // Log as JSON for structured logging
+            if let Ok(json) = events.to_json() {
+                log::debug!("TCP Fingerprint Events JSON: {}", json);
+            }
+
+            // Send events to unified queue
+            for event in events.events {
+                send_event(UnifiedEvent::TcpFingerprint(event));
+            }
+
+            // Reset the counters after logging
+            self.reset_fingerprint_counters()?;
+        } else {
+            log::debug!("No TCP fingerprint events found");
+        }
+
+        Ok(())
+    }
+
+
+    /// Reset TCP fingerprint counters in BPF maps
+    pub fn reset_fingerprint_counters(&self) -> Result<(), Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(());
+        }
+
+        log::debug!("Resetting TCP fingerprint counters");
+
+        for skel in &self.skels {
+            // Reset TCP fingerprints map
+            match skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+                Ok(batch_iter) => {
+                    let mut reset_count = 0;
+                    for (key_bytes, _) in batch_iter {
+                        if key_bytes.len() >= 20 {
+                            // Create zero value for fingerprint data (48 bytes with padding)
+                            let zero_value = vec![0u8; 48];
+                            if let Err(e) = skel.maps.tcp_fingerprints.update(&key_bytes, &zero_value, libbpf_rs::MapFlags::ANY) {
+                                log::warn!("Failed to reset TCP fingerprint counter: {}", e);
+                            } else {
+                                reset_count += 1;
+                            }
+                        }
+                    }
+                    log::debug!("Reset {} TCP fingerprint counters", reset_count);
+                }
+                Err(e) => {
+                    log::warn!("Failed to reset TCP fingerprint counters: {}", e);
+                }
+            }
+
+            // Reset TCP SYN stats
+            let key = 0u32.to_le_bytes();
+            let zero_stats = vec![0u8; 24]; // 3 * u64 = 24 bytes
+            if let Err(e) = skel.maps.tcp_syn_stats.update(&key, &zero_stats, libbpf_rs::MapFlags::ANY) {
+                log::warn!("Failed to reset TCP SYN stats: {}", e);
+            } else {
+                log::debug!("Reset TCP SYN stats");
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Check if BPF maps are accessible
+    pub fn check_maps_accessible(&self) -> Result<(), Box<dyn std::error::Error>> {
+        if !self.enabled {
+            return Ok(());
+        }
+
+        for (i, skel) in self.skels.iter().enumerate() {
+            log::debug!("Checking accessibility of BPF maps for skeleton {}", i);
+
+            // Check tcp_fingerprints map
+            match skel.maps.tcp_fingerprints.lookup_batch(1, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) {
+                Ok(_) => log::debug!("tcp_fingerprints map is accessible for skeleton {}", i),
+                Err(e) => log::warn!("tcp_fingerprints map not accessible for skeleton {}: {}", i, e),
+            }
+
+            // Check tcp_syn_stats map
+            let key = 0u32.to_le_bytes();
+            match skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY) {
+                Ok(_) => log::debug!("tcp_syn_stats map is accessible for skeleton {}", i),
+                Err(e) => log::warn!("tcp_syn_stats map not accessible for skeleton {}: {}", i, e),
+            }
+        }
+
+        Ok(())
+    }
+}
+
+/// Configuration for TCP fingerprint collection
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TcpFingerprintCollectorConfig {
+    pub enabled: bool,
+    pub log_interval_secs: u64,
+    pub fingerprint_events_interval_secs: u64,
+}
+
+impl Default for TcpFingerprintCollectorConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            log_interval_secs: 60, // Log stats every minute
+            fingerprint_events_interval_secs: 30, // Send events every 30 seconds
+        }
+    }
+}
+
+impl TcpFingerprintCollectorConfig {
+    /// Create a new configuration
+    pub fn new(enabled: bool, log_interval_secs: u64, fingerprint_events_interval_secs: u64) -> Self {
+        Self {
+            enabled,
+            log_interval_secs,
+            fingerprint_events_interval_secs,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tcp_fingerprint_config_default() {
+        let config = TcpFingerprintConfig::default();
+        assert!(config.enabled);
+        assert_eq!(config.log_interval_secs, 60);
+        assert!(config.enable_fingerprint_events);
+        assert_eq!(config.fingerprint_events_interval_secs, 30);
+    }
+
+    #[test]
+    fn test_tcp_fingerprint_collector_config_default() {
+        let config = TcpFingerprintCollectorConfig::default();
+        assert!(config.enabled);
+        assert_eq!(config.log_interval_secs, 60);
+        assert_eq!(config.fingerprint_events_interval_secs, 30);
+    }
+
+    #[test]
+    fn test_tcp_fingerprint_event_summary() {
+        let event = TcpFingerprintEvent {
+            event_type: "tcp_fingerprint".to_string(),
+            timestamp: Utc::now(),
+            src_ip: "192.168.1.1".to_string(),
+            src_port: 80,
+            fingerprint: "64:1460:65535:7".to_string(),
+            ttl: 64,
+            mss: 1460,
+            window_size: 65535,
+            window_scale: 7,
+            packet_count: 1
+        };
+
+        let summary = event.summary();
+        assert!(summary.contains("192.168.1.1"));
+        assert!(summary.contains("80"));
+        assert!(summary.contains("64:1460:65535:7"));
+    }
+
+    #[test]
+    fn test_unique_fingerprint_stats() {
+        let stats = UniqueFingerprintStats {
+            timestamp: Utc::now(),
+            total_unique_patterns: 2,
+            total_unique_ips: 3,
+            total_packets: 100,
+            patterns: vec![
+                UniqueFingerprintPattern {
+                    pattern: "64:1460:65535:7".to_string(),
+                    packet_count: 60,
+                    unique_ips: 2,
+                    entries: 2,
+                },
+                UniqueFingerprintPattern {
+                    pattern: "128:1460:32768:8".to_string(),
+                    packet_count: 40,
+                    unique_ips: 1,
+                    entries: 1,
+                },
+            ],
+        };
+
+        let summary = stats.summary();
+        assert!(summary.contains("2 patterns"));
+        assert!(summary.contains("3 unique IPs"));
+        assert!(summary.contains("100 total packets"));
+
+        let json = stats.to_json().unwrap();
+        assert!(json.contains("total_unique_patterns"));
+        assert!(json.contains("patterns"));
+    }
+}
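[Editorial sketch, not part of the patch: a brief usage note for the collector API above, before the matching no-op shim below. The wiring here is assumed — skeleton loading, the feature gates, and the scheduling loop live elsewhere in this series — so treat `start_fingerprinting` and its arguments as illustrative only.]

    // Hypothetical wiring; assumes `skels` was loaded elsewhere (not in this patch).
    fn start_fingerprinting(skels: Vec<Arc<FilterSkel<'static>>>, cfg: TcpFingerprintConfig) {
        let collector = TcpFingerprintCollector::new_with_config(skels, cfg);

        // Map-access problems are logged, not fatal: the collector fails soft.
        if let Err(e) = collector.check_maps_accessible() {
            log::warn!("TCP fingerprint maps not ready: {}", e);
        }

        set_global_tcp_fingerprint_collector(collector);

        // Later, on the events interval, drain into the unified queue;
        // log_fingerprint_events also resets the BPF counters afterwards.
        if let Some(c) = get_global_tcp_fingerprint_collector() {
            let _ = c.log_fingerprint_events();
        }
    }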
event.summary(); + assert!(summary.contains("192.168.1.1")); + assert!(summary.contains("80")); + assert!(summary.contains("64:1460:65535:7")); + } + + #[test] + fn test_unique_fingerprint_stats() { + let stats = UniqueFingerprintStats { + timestamp: Utc::now(), + total_unique_patterns: 2, + total_unique_ips: 3, + total_packets: 100, + patterns: vec![ + UniqueFingerprintPattern { + pattern: "64:1460:65535:7".to_string(), + packet_count: 60, + unique_ips: 2, + entries: 2, + }, + UniqueFingerprintPattern { + pattern: "128:1460:32768:8".to_string(), + packet_count: 40, + unique_ips: 1, + entries: 1, + }, + ], + }; + + let summary = stats.summary(); + assert!(summary.contains("2 patterns")); + assert!(summary.contains("3 unique IPs")); + assert!(summary.contains("100 total packets")); + + let json = stats.to_json().unwrap(); + assert!(json.contains("total_unique_patterns")); + assert!(json.contains("patterns")); + } +} diff --git a/src/utils/tcp_fingerprint_noop.rs b/src/utils/tcp_fingerprint_noop.rs index 5a84b34..449dc65 100644 --- a/src/utils/tcp_fingerprint_noop.rs +++ b/src/utils/tcp_fingerprint_noop.rs @@ -1,166 +1,166 @@ -use std::net::IpAddr; -use std::sync::Arc; - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; - -use crate::bpf::FilterSkel; - -/// TCP fingerprinting configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintConfig { - pub enabled: bool, - pub log_interval_secs: u64, - pub enable_fingerprint_events: bool, - pub fingerprint_events_interval_secs: u64, - pub min_packet_count: u32, - pub min_connection_duration_secs: u64, -} - -impl Default for TcpFingerprintConfig { - fn default() -> Self { - Self { - enabled: false, - log_interval_secs: 60, - enable_fingerprint_events: false, - fingerprint_events_interval_secs: 30, - min_packet_count: 3, - min_connection_duration_secs: 1, - } - } -} - -impl TcpFingerprintConfig { - pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self { - Self { - enabled: cli_config.enabled, - log_interval_secs: cli_config.log_interval_secs, - enable_fingerprint_events: cli_config.enable_fingerprint_events, - fingerprint_events_interval_secs: cli_config.fingerprint_events_interval_secs, - min_packet_count: cli_config.min_packet_count, - min_connection_duration_secs: cli_config.min_connection_duration_secs, - } - } -} - -/// TCP fingerprint data collected from BPF -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintData { - pub first_seen: DateTime, - pub last_seen: DateTime, - pub packet_count: u32, - pub ttl: u16, - pub mss: u16, - pub window_size: u16, - pub window_scale: u8, - pub options_len: u8, - pub options: Vec, -} - -/// TCP fingerprint event for API -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintEvent { - pub event_type: String, - pub timestamp: DateTime, - pub src_ip: String, - pub src_port: u16, - pub fingerprint: String, - pub ttl: u16, - pub mss: u16, - pub window_size: u16, - pub window_scale: u8, - pub packet_count: u32, -} - -impl TcpFingerprintEvent { - pub fn summary(&self) -> String { - format!( - "TCP fingerprint event: {}:{} {}", - self.src_ip, self.src_port, self.fingerprint - ) - } -} - -/// Collection of TCP fingerprint events -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TcpFingerprintEvents { - pub events: Vec, - pub total_events: u64, - pub unique_ips: u64, -} - -impl TcpFingerprintEvents { - pub fn new() -> Self { - Self { - events: Vec::new(), - total_events: 0, - unique_ips: 0, - } - 
} - - pub fn summary(&self) -> String { - format!( - "TCP fingerprint events: {} events from {} unique IPs", - self.total_events, - self.unique_ips - ) - } - - pub fn get_top_fingerprints(&self, limit: usize) -> Vec { - let mut events = self.events.clone(); - events.sort_by(|a, b| b.packet_count.cmp(&a.packet_count)); - events.into_iter().take(limit).collect() - } -} - -static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock> = std::sync::OnceLock::new(); - -pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) { - let _ = TCP_FINGERPRINT_COLLECTOR.set(Arc::new(collector)); -} - -pub fn get_global_tcp_fingerprint_collector() -> Option> { - TCP_FINGERPRINT_COLLECTOR.get().cloned() -} - -#[derive(Clone)] -pub struct TcpFingerprintCollector { - enabled: bool, - _skels: Vec>>, -} - -impl TcpFingerprintCollector { - pub fn new(skels: Vec>>, enabled: bool) -> Self { - Self { enabled, _skels: skels } - } - - pub fn new_with_config(skels: Vec>>, config: TcpFingerprintConfig) -> Self { - Self { - enabled: config.enabled, - _skels: skels, - } - } - - pub fn lookup_fingerprint(&self, _src_ip: IpAddr, _src_port: u16) -> Option { - None - } - - pub fn collect_fingerprint_events(&self) -> Result> { - Ok(TcpFingerprintEvents::new()) - } - - pub fn log_stats(&self) -> Result<(), Box> { - if self.enabled { - log::debug!("TCP fingerprint stats disabled (BPF support not built)"); - } - Ok(()) - } - - pub fn log_fingerprint_events(&self) -> Result<(), Box> { - Ok(()) - } - - pub fn log_events(&self) -> Result<(), Box> { - Ok(()) - } -} +use std::net::IpAddr; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::bpf::FilterSkel; + +/// TCP fingerprinting configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintConfig { + pub enabled: bool, + pub log_interval_secs: u64, + pub enable_fingerprint_events: bool, + pub fingerprint_events_interval_secs: u64, + pub min_packet_count: u32, + pub min_connection_duration_secs: u64, +} + +impl Default for TcpFingerprintConfig { + fn default() -> Self { + Self { + enabled: false, + log_interval_secs: 60, + enable_fingerprint_events: false, + fingerprint_events_interval_secs: 30, + min_packet_count: 3, + min_connection_duration_secs: 1, + } + } +} + +impl TcpFingerprintConfig { + pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self { + Self { + enabled: cli_config.enabled, + log_interval_secs: cli_config.log_interval_secs, + enable_fingerprint_events: cli_config.enable_fingerprint_events, + fingerprint_events_interval_secs: cli_config.fingerprint_events_interval_secs, + min_packet_count: cli_config.min_packet_count, + min_connection_duration_secs: cli_config.min_connection_duration_secs, + } + } +} + +/// TCP fingerprint data collected from BPF +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintData { + pub first_seen: DateTime, + pub last_seen: DateTime, + pub packet_count: u32, + pub ttl: u16, + pub mss: u16, + pub window_size: u16, + pub window_scale: u8, + pub options_len: u8, + pub options: Vec, +} + +/// TCP fingerprint event for API +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintEvent { + pub event_type: String, + pub timestamp: DateTime, + pub src_ip: String, + pub src_port: u16, + pub fingerprint: String, + pub ttl: u16, + pub mss: u16, + pub window_size: u16, + pub window_scale: u8, + pub packet_count: u32, +} + +impl TcpFingerprintEvent { + pub fn summary(&self) -> String { + 
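+        // Presumably kept in the same shape as the BPF-backed build's summary so
+        // downstream log consumers see one format regardless of build variant.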
format!( + "TCP fingerprint event: {}:{} {}", + self.src_ip, self.src_port, self.fingerprint + ) + } +} + +/// Collection of TCP fingerprint events +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpFingerprintEvents { + pub events: Vec, + pub total_events: u64, + pub unique_ips: u64, +} + +impl TcpFingerprintEvents { + pub fn new() -> Self { + Self { + events: Vec::new(), + total_events: 0, + unique_ips: 0, + } + } + + pub fn summary(&self) -> String { + format!( + "TCP fingerprint events: {} events from {} unique IPs", + self.total_events, + self.unique_ips + ) + } + + pub fn get_top_fingerprints(&self, limit: usize) -> Vec { + let mut events = self.events.clone(); + events.sort_by(|a, b| b.packet_count.cmp(&a.packet_count)); + events.into_iter().take(limit).collect() + } +} + +static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock> = std::sync::OnceLock::new(); + +pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) { + let _ = TCP_FINGERPRINT_COLLECTOR.set(Arc::new(collector)); +} + +pub fn get_global_tcp_fingerprint_collector() -> Option> { + TCP_FINGERPRINT_COLLECTOR.get().cloned() +} + +#[derive(Clone)] +pub struct TcpFingerprintCollector { + enabled: bool, + _skels: Vec>>, +} + +impl TcpFingerprintCollector { + pub fn new(skels: Vec>>, enabled: bool) -> Self { + Self { enabled, _skels: skels } + } + + pub fn new_with_config(skels: Vec>>, config: TcpFingerprintConfig) -> Self { + Self { + enabled: config.enabled, + _skels: skels, + } + } + + pub fn lookup_fingerprint(&self, _src_ip: IpAddr, _src_port: u16) -> Option { + None + } + + pub fn collect_fingerprint_events(&self) -> Result> { + Ok(TcpFingerprintEvents::new()) + } + + pub fn log_stats(&self) -> Result<(), Box> { + if self.enabled { + log::debug!("TCP fingerprint stats disabled (BPF support not built)"); + } + Ok(()) + } + + pub fn log_fingerprint_events(&self) -> Result<(), Box> { + Ok(()) + } + + pub fn log_events(&self) -> Result<(), Box> { + Ok(()) + } +} diff --git a/src/utils/tls.rs b/src/utils/tls.rs index 69d86b3..3c479e2 100644 --- a/src/utils/tls.rs +++ b/src/utils/tls.rs @@ -1,652 +1,652 @@ -use dashmap::DashMap; -use log::{error, info, warn}; -use pingora_core::tls::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef, SslVersion}; -use pingora_core::listeners::tls::TlsSettings; -use pingora_core::listeners::TlsAccept; -use rustls_pemfile::{read_one, Item}; -use serde::Deserialize; -use std::collections::HashSet; -use std::fs::File; -use std::io::BufReader; -use std::sync::Arc; -use once_cell::sync::OnceCell; -use async_trait::async_trait; -use x509_parser::extensions::GeneralName; -use x509_parser::nom::Err as NomErr; -use x509_parser::prelude::*; - -// Global certificate store for SNI callback -static GLOBAL_CERTIFICATES: OnceCell> = OnceCell::new(); - -/// Set the global certificates for SNI callback -pub fn set_global_certificates(certificates: Arc) { - let _ = GLOBAL_CERTIFICATES.set(certificates); -} - -/// Get the global certificates for SNI callback -fn get_global_certificates() -> Option> { - GLOBAL_CERTIFICATES.get().cloned() -} - -#[derive(Clone, Deserialize, Debug)] -pub struct CertificateConfig { - pub cert_path: String, - pub key_path: String, -} - -#[derive(Clone, Debug)] -struct CertificateInfo { - common_names: Vec, - alt_names: Vec, - ssl_context: SslContext, - cert_path: String, - #[allow(dead_code)] // Used during Certificates initialization - key_path: String, -} - -#[derive(Clone, Debug)] -pub 
struct Certificates { - configs: Vec, - name_map: DashMap, - // Map from certificate name (e.g., "arxignis.dev") to SSL context - cert_name_map: DashMap, - // Map from hostname (e.g., "david-playground3.arxignis.dev") to certificate name (e.g., "arxignis.dev") - upstreams_cert_map: DashMap, - pub default_cert_path: String, - pub default_key_path: String, -} - -// Implement TlsAccept trait for dynamic certificate selection based on SNI -#[async_trait] -impl TlsAccept for Certificates { - async fn certificate_callback(&self, ssl: &mut SslRef) { - if let Some(server_name) = ssl.servername(NameType::HOST_NAME) { - let name_str = server_name.to_string(); - log::info!("TlsAccept::certificate_callback invoked for hostname: {}", name_str); - log::debug!("TlsAccept: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); - log::debug!("TlsAccept: cert_name_map has {} entries", self.cert_name_map.len()); - - // Find the matching SSL context for this hostname - if let Some(ctx) = self.find_ssl_context(&name_str) { - // Log which certificate was found (will be logged in find_ssl_context) - log::info!("TlsAccept: Found matching certificate for hostname: {} (see details above)", name_str); - - // Get the certificate and key from the SSL context - // We need to extract them from the context to use with ssl_use_certificate - // However, SslContext doesn't expose the certificate/key directly - // So we'll use set_ssl_context instead, which should work - match ssl.set_ssl_context(&*ctx) { - Ok(_) => { - log::info!("TlsAccept: Successfully set SSL context for hostname: {}", name_str); - return; - } - Err(e) => { - log::error!("TlsAccept: Failed to set SSL context for hostname {}: {:?}", name_str, e); - // Fall through to use default certificate - } - } - } else { - log::warn!("TlsAccept: No matching certificate found for hostname: {}, using default", name_str); - } - } else { - log::debug!("TlsAccept: No SNI provided, using default certificate"); - } - - // Use default certificate - get it by name from default_cert_path - let default_cert_name = std::path::Path::new(&self.default_cert_path) - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("default"); - - if let Some(default_ctx) = self.cert_name_map.get(default_cert_name) { - let ctx = default_ctx.value(); - log::info!("TlsAccept: Using configured default certificate: {}", default_cert_name); - if let Err(e) = ssl.set_ssl_context(&*ctx) { - log::error!("TlsAccept: Failed to set default SSL context: {:?}", e); - } else { - log::debug!("TlsAccept: Successfully set default certificate"); - } - } else { - // Fallback to first available certificate if default not found - log::warn!("TlsAccept: Default certificate '{}' not found in cert_name_map, using first available", default_cert_name); - if let Some(default_ctx) = self.cert_name_map.iter().next() { - let ctx = default_ctx.value(); - if let Err(e) = ssl.set_ssl_context(&*ctx) { - log::error!("TlsAccept: Failed to set fallback SSL context: {:?}", e); - } else { - log::debug!("TlsAccept: Using fallback certificate"); - } - } else { - log::error!("TlsAccept: No certificates available!"); - } - } - } -} - -impl Certificates { - pub fn new(configs: &Vec, _grade: &str, default_certificate: Option<&String>) -> Option { - Self::new_with_sni_callback(configs, _grade, default_certificate, None) - } - - pub fn new_with_sni_callback( - configs: &Vec, - _grade: &str, - default_certificate: Option<&String>, - _certificates_for_callback: Option>, - ) -> Option { - if configs.is_empty() { - warn!("No TLS 
certificates found, TLS will be disabled until certificates are added"); - return None; - } - - // First, create a temporary Certificates struct to get access to it in the callback - // We'll recreate it properly after loading all certificates - let mut cert_infos = Vec::new(); - let name_map: DashMap = DashMap::new(); - let mut valid_configs = Vec::new(); - - for config in configs { - let cert_info = load_cert_info(&config.cert_path, &config.key_path, _grade); - match cert_info { - Some(cert) => { - for name in &cert.common_names { - name_map.insert(name.clone(), cert.ssl_context.clone()); - } - for name in &cert.alt_names { - name_map.insert(name.clone(), cert.ssl_context.clone()); - } - - cert_infos.push(cert); - valid_configs.push(config.clone()); - } - None => { - warn!("Skipping invalid certificate: cert={}, key={}", &config.cert_path, &config.key_path); - // Continue with other certificates instead of failing - } - } - } - - if cert_infos.is_empty() { - error!("No valid certificates could be loaded from {} certificate configs", configs.len()); - return None; - } - - // Find default certificate: use configured default_certificate if specified, otherwise use first valid certificate - let default_cert = if let Some(default_cert_name) = default_certificate { - // Try to find certificate by name (file stem without extension) - let found = valid_configs.iter().find(|config| { - if let Some(file_name) = std::path::Path::new(&config.cert_path) - .file_stem() - .and_then(|s| s.to_str()) - { - file_name == default_cert_name.as_str() - } else { - false - } - }); - match found { - Some(cert) => { - log::info!("Using configured default certificate: {}", default_cert_name); - cert - } - None => { - log::warn!("Configured default certificate '{}' not found, using first valid certificate", default_cert_name); - &valid_configs[0] - } - } - } else { - // Use first valid certificate as default - &valid_configs[0] - }; - - // Build cert_name_map: map from certificate file name (without extension) to SSL context - let cert_name_map: DashMap = DashMap::new(); - for (idx, config) in valid_configs.iter().enumerate() { - // Extract certificate name from path (e.g., "/path/to/arxignis.dev.crt" -> "arxignis.dev") - // Use file_stem() to get the filename without extension - if let Some(file_name) = std::path::Path::new(&config.cert_path) - .file_stem() - .and_then(|s| s.to_str()) - { - if let Some(cert_info) = cert_infos.get(idx) { - let cert_name = file_name.to_string(); - cert_name_map.insert(cert_name.clone(), cert_info.ssl_context.clone()); - log::debug!("Mapped certificate name '{}' to SSL context (from path: {})", cert_name, config.cert_path); - } - } else { - log::warn!("Failed to extract certificate name from path: {}", config.cert_path); - } - } - - log::debug!("Built cert_name_map with {} entries", cert_name_map.len()); - - Some(Self { - name_map: name_map, - cert_name_map: cert_name_map, - upstreams_cert_map: DashMap::new(), - configs: cert_infos, - default_cert_path: default_cert.cert_path.clone(), - default_key_path: default_cert.key_path.clone(), - }) - } - - /// Set upstreams certificate mappings (hostname -> certificate_name) - /// The certificate_name should match the file stem used in cert_name_map - /// (i.e., normalized and sanitized: remove wildcard prefix, replace . 
with _) - pub fn set_upstreams_cert_map(&self, mappings: DashMap) { - self.upstreams_cert_map.clear(); - for entry in mappings.iter() { - let hostname = entry.key().clone(); - let cert_name = entry.value().clone(); - // Normalize certificate name to match file stem format used in cert_name_map - // Remove wildcard prefix if present, then sanitize (replace . with _) - let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(&cert_name); - let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); - self.upstreams_cert_map.insert(hostname.clone(), sanitized_cert_name.clone()); - log::info!("Mapped hostname '{}' to certificate '{}' (normalized from '{}')", hostname, sanitized_cert_name, cert_name); - } - log::info!("Set upstreams certificate mappings: {} entries", self.upstreams_cert_map.len()); - } - - fn find_ssl_context(&self, server_name: &str) -> Option { - log::debug!("Finding SSL context for server_name: {}", server_name); - log::debug!("upstreams_cert_map entries: {:?}", - self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); - log::debug!("cert_name_map entries: {:?}", - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); - - // First, check if there's an upstreams mapping for this hostname - if let Some(cert_name) = self.upstreams_cert_map.get(server_name) { - let cert_name_str = cert_name.value(); - log::info!("Found upstreams mapping: {} -> {}", server_name, cert_name_str); - if let Some(ctx) = self.cert_name_map.get(cert_name_str) { - log::info!("Using certificate '{}' for hostname '{}' via upstreams mapping", cert_name_str, server_name); - return Some(ctx.clone()); - } else { - // Certificate specified in upstreams.yaml but doesn't exist - use default instead of searching further - log::warn!("Certificate '{}' specified in upstreams config for hostname '{}' not found in cert_name_map. Available certificates: {:?}. 
Will use default certificate (NOT searching for wildcards).", - cert_name_str, server_name, - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); - return None; // Return None to use default certificate - DO NOT continue searching - } - } else { - log::debug!("No upstreams mapping found for hostname: {}, will search for exact/wildcard matches", server_name); - } - - // Then, try exact match in name_map (from certificate CN/SAN) - if let Some(ctx) = self.name_map.get(server_name) { - log::info!("Found certificate via CN/SAN exact match for: {}", server_name); - return Some(ctx.clone()); - } - - // Try wildcard match from certificate CN/SAN before falling back to default - for config in &self.configs { - for name in &config.common_names { - if name.starts_with("*.") && server_name.ends_with(&name[1..]) { - log::info!("Found certificate via CN wildcard match: {} matches {}", server_name, name); - return Some(config.ssl_context.clone()); - } - } - for name in &config.alt_names { - if name.starts_with("*.") && server_name.ends_with(&name[1..]) { - log::info!("Found certificate via SAN wildcard match: {} matches {}", server_name, name); - return Some(config.ssl_context.clone()); - } - } - } - - // Check if default certificate is configured - use it as fallback - let default_cert_name = std::path::Path::new(&self.default_cert_path) - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("default"); - - // If default certificate exists and is configured, use it as fallback - if self.cert_name_map.contains_key(default_cert_name) { - log::info!("No exact or wildcard match found for '{}', will use default certificate '{}'", server_name, default_cert_name); - return None; // Return None to use default certificate - } - - log::warn!("No matching certificate found for hostname: {}, will use default certificate", server_name); - None - } - - pub fn server_name_callback(&self, ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert) -> Result<(), SniError> { - let server_name_opt = ssl_ref.servername(NameType::HOST_NAME); - log::info!("TLS server_name_callback invoked: server_name = {:?}", server_name_opt); - if let Some(name) = server_name_opt { - let name_str = name.to_string(); - log::info!("SNI callback: Looking up certificate for hostname: {}", name_str); - log::debug!("SNI callback: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); - log::debug!("SNI callback: cert_name_map has {} entries", self.cert_name_map.len()); - - match self.find_ssl_context(&name_str) { - Some(ctx) => { - log::info!("SNI callback: Found matching certificate for hostname: {}", name_str); - log::info!("SNI callback: Setting SSL context for hostname: {}", name_str); - ssl_ref.set_ssl_context(&*ctx).map_err(|e| { - log::error!("SNI callback: Failed to set SSL context for hostname {}: {:?}", name_str, e); - SniError::ALERT_FATAL - })?; - log::info!("SNI callback: Successfully set SSL context for hostname: {}", name_str); - } - None => { - log::warn!("SNI callback: No matching certificate found for hostname: {}, using default certificate", name_str); - log::debug!("SNI callback: Available upstreams mappings: {:?}", - self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); - log::debug!("SNI callback: Available certificate names: {:?}", - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); - // Don't set a context - let it use the default - } - } - } else { - log::debug!("SNI callback: No server name (SNI) provided in TLS handshake"); - } - Ok(()) - } - - /// 
Get certificate path for a given hostname - pub fn get_cert_path_for_hostname(&self, hostname: &str) -> Option { - // First try exact match - if self.name_map.contains_key(hostname) { - // Find the certificate info that matches this hostname - for config in &self.configs { - if config.common_names.contains(&hostname.to_string()) || config.alt_names.contains(&hostname.to_string()) { - return Some(config.cert_path.clone()); - } - } - } - - // Try wildcard match - for config in &self.configs { - for name in &config.common_names { - if name.starts_with("*.") && hostname.ends_with(&name[1..]) { - return Some(config.cert_path.clone()); - } - } - for name in &config.alt_names { - if name.starts_with("*.") && hostname.ends_with(&name[1..]) { - return Some(config.cert_path.clone()); - } - } - } - - // Return default certificate path if no match found - Some(self.default_cert_path.clone()) - } -} - -fn load_cert_info(cert_path: &str, key_path: &str, _grade: &str) -> Option { - let mut common_names = HashSet::new(); - let mut alt_names = HashSet::new(); - - let file = File::open(cert_path); - match file { - Err(e) => { - log::error!("Failed to open certificate file: {:?}", e); - return None; - } - Ok(file) => { - let mut reader = BufReader::new(file); - match read_one(&mut reader) { - Err(e) => { - log::error!("Failed to decode PEM from certificate file: {:?}", e); - return None; - } - Ok(leaf) => match leaf { - Some(Item::X509Certificate(cert)) => match X509Certificate::from_der(&cert) { - Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => { - log::error!("Failed to parse certificate: {:?}", e); - return None; - } - Err(_) => { - log::error!("Unknown error while parsing certificate"); - return None; - } - Ok((_, x509)) => { - let subject = x509.subject(); - for attr in subject.iter_common_name() { - if let Ok(cn) = attr.as_str() { - common_names.insert(cn.to_string()); - } - } - - if let Ok(Some(san)) = x509.subject_alternative_name() { - for name in san.value.general_names.iter() { - if let GeneralName::DNSName(dns) = name { - let dns_string = dns.to_string(); - if !common_names.contains(&dns_string) { - alt_names.insert(dns_string); - } - } - } - } - } - }, - _ => { - log::error!("Failed to read certificate"); - return None; - } - }, - } - } - } - - match create_ssl_context(cert_path, key_path) { - Ok(ssl_context) => { - Some(CertificateInfo { - common_names: common_names.into_iter().collect(), - alt_names: alt_names.into_iter().collect(), - ssl_context, - cert_path: cert_path.to_string(), - key_path: key_path.to_string(), - }) - } - Err(e) => { - log::error!("Failed to create SSL context from cert paths '{}' and '{}': {}", cert_path, key_path, e); - None - } - } -} - -fn create_ssl_context(cert_path: &str, key_path: &str) -> Result> { - // Always try to use global certificates for SNI callback - // This ensures that even contexts created without explicit certificates - // will have the SNI callback set if global certificates are available - create_ssl_context_with_sni_callback(cert_path, key_path, None) -} - -fn create_ssl_context_with_sni_callback( - cert_path: &str, - key_path: &str, - certificates: Option>, -) -> Result> { - let mut ctx = SslContext::builder(SslMethod::tls()) - .map_err(|e| format!("Failed to create SSL context builder: {}", e))?; - - ctx.set_certificate_chain_file(cert_path) - .map_err(|e| format!("Failed to set certificate chain file '{}': {}", cert_path, e))?; - - ctx.set_private_key_file(key_path, SslFiletype::PEM) - .map_err(|e| format!("Failed to set private key file 
'{}': {}", key_path, e))?; - - ctx.set_alpn_select_callback(prefer_h2); - - // Set SNI callback - use provided certificates or global certificates - let certs_for_callback = certificates.or_else(get_global_certificates); - if let Some(certs) = certs_for_callback { - let certs_clone = certs.clone(); - ctx.set_servername_callback(move |ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert| -> Result<(), SniError> { - certs_clone.server_name_callback(ssl_ref, _ssl_alert) - }); - log::debug!("Set SNI callback on SSL context for certificate selection"); - } else { - // Certificates may not be loaded yet (e.g., during startup before Redis certificates are fetched) - // This is expected during initialization, so use debug level instead of warn - static WARNED: std::sync::Once = std::sync::Once::new(); - WARNED.call_once(|| { - log::debug!("No certificates available for SNI callback yet - certificates will be loaded asynchronously. Certificate selection by hostname will work once certificates are loaded."); - }); - } - - let built = ctx.build(); - - Ok(built) -} - -#[derive(Debug)] -pub struct CipherSuite { - pub high: &'static str, - pub medium: &'static str, - pub legacy: &'static str, -} -const CIPHERS: CipherSuite = CipherSuite { - high: "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305", - medium: "ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:AES128-GCM-SHA256", - legacy: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH", -}; - -#[derive(Debug)] -pub enum TlsGrade { - HIGH, - MEDIUM, - LEGACY, -} - -impl TlsGrade { - pub fn from_str(s: &str) -> Option { - match s.to_ascii_lowercase().as_str() { - "high" => Some(TlsGrade::HIGH), - "medium" => Some(TlsGrade::MEDIUM), - "unsafe" => Some(TlsGrade::LEGACY), - _ => None, - } - } -} -pub fn prefer_h2<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> { - match select_next_proto("\x02h2\x08http/1.1".as_bytes(), alpn_in) { - Some(p) => Ok(p), - _ => Err(AlpnError::NOACK), - } -} - -// Helper to set ALPN on TlsSettings -pub fn set_alpn_prefer_h2(tls_settings: &mut pingora_core::listeners::tls::TlsSettings) { - use pingora_core::listeners::ALPN; - tls_settings.set_alpn(ALPN::H2H1); -} - -// Helper to create TlsSettings with SNI callback for certificate selection -// This uses TlsSettings::with_callbacks() which allows us to provide a TlsAccept implementation -// that handles dynamic certificate selection based on SNI (Server Name Indication) -pub fn create_tls_settings_with_sni( - cert_path: &str, - key_path: &str, - grade: &str, - certificates: Option>, -) -> Result> { - // Get the certificates - use provided or fall back to global - let certs = certificates - .or_else(get_global_certificates) - .ok_or_else(|| "No certificates available for TLS configuration".to_string())?; - - log::info!("Creating TlsSettings with callbacks for dynamic certificate selection"); - log::info!("Default certificate: {} / {}", cert_path, key_path); - log::info!("Certificate mappings: {} upstreams, {} certificates", - certs.upstreams_cert_map.len(), certs.cert_name_map.len()); - - // Use TlsSettings::with_callbacks() instead of TlsSettings::intermediate() - // This allows us to provide our Certificates struct which implements TlsAccept - // The certificate_callback method will be called during TLS handshake to select - // the appropriate certificate based on the SNI hostname - // - // Note: with_callbacks expects a Box - // We clone 
the Certificates struct to create a new instance for the callback - let tls_accept: Box = Box::new((*certs).clone()); - let mut tls_settings = TlsSettings::with_callbacks(tls_accept) - .map_err(|e| format!("Failed to create TlsSettings with callbacks: {}", e))?; - - // Configure TLS grade and ALPN - set_tsl_grade(&mut tls_settings, grade); - set_alpn_prefer_h2(&mut tls_settings); - - log::info!("Successfully created TlsSettings with SNI-based certificate selection"); - log::info!("Certificate selection will work based on hostname from SNI"); - - Ok(tls_settings) -} - -pub fn set_tsl_grade(tls_settings: &mut TlsSettings, grade: &str) { - let config_grade = TlsGrade::from_str(grade); - match config_grade { - Some(TlsGrade::HIGH) => { - let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1_2)); - // let _ = tls_settings.set_max_proto_version(Some(SslVersion::TLS1_3)); - let _ = tls_settings.set_cipher_list(CIPHERS.high); - let _ = tls_settings.set_ciphersuites(CIPHERS.high); - info!("TLS grade: => HIGH"); - } - Some(TlsGrade::MEDIUM) => { - let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1)); - let _ = tls_settings.set_cipher_list(CIPHERS.medium); - let _ = tls_settings.set_ciphersuites(CIPHERS.medium); - info!("TLS grade: => MEDIUM"); - } - Some(TlsGrade::LEGACY) => { - let _ = tls_settings.set_min_proto_version(Some(SslVersion::SSL3)); - let _ = tls_settings.set_cipher_list(CIPHERS.legacy); - let _ = tls_settings.set_ciphersuites(CIPHERS.legacy); - warn!("TLS grade: => UNSAFE"); - } - None => { - // Defaults to MEDIUM - let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1)); - let _ = tls_settings.set_cipher_list(CIPHERS.medium); - let _ = tls_settings.set_ciphersuites(CIPHERS.medium); - warn!("TLS grade is not detected defaulting top MEDIUM"); - } - } -} - -/// Extract server certificate information for access logging -pub fn extract_cert_info(cert_path: &str) -> Option { - use sha2::{Digest, Sha256}; - - let file = File::open(cert_path).ok()?; - let mut reader = BufReader::new(file); - - // Read the first certificate from the PEM file - let item = read_one(&mut reader).ok()??; - - let cert_der = match item { - Item::X509Certificate(der) => der, - _ => return None, - }; - - // Parse the X.509 certificate - let (_, cert) = X509Certificate::from_der(&cert_der).ok()?; - - // Extract issuer - let issuer = cert.issuer().to_string(); - - // Extract subject - let subject = cert.subject().to_string(); - - // Extract validity dates (as ISO 8601 format) - let not_before = cert.validity().not_before.to_datetime().to_string(); - let not_after = cert.validity().not_after.to_datetime().to_string(); - - // Calculate SHA256 fingerprint - let mut hasher = Sha256::new(); - hasher.update(&cert_der); - let fingerprint_sha256 = format!("{:x}", hasher.finalize()); - - Some(crate::access_log::ServerCertInfo { - issuer, - subject, - not_before, - not_after, - fingerprint_sha256, - }) -} - +use dashmap::DashMap; +use log::{error, info, warn}; +use pingora_core::tls::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef, SslVersion}; +use pingora_core::listeners::tls::TlsSettings; +use pingora_core::listeners::TlsAccept; +use rustls_pemfile::{read_one, Item}; +use serde::Deserialize; +use std::collections::HashSet; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; +use once_cell::sync::OnceCell; +use async_trait::async_trait; +use x509_parser::extensions::GeneralName; +use x509_parser::nom::Err as NomErr; 
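+// x509_parser supplies the CN/SAN extraction used by load_cert_info() below to
+// build the hostname -> SslContext maps consumed during SNI selection.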
+use x509_parser::prelude::*; + +// Global certificate store for SNI callback +static GLOBAL_CERTIFICATES: OnceCell> = OnceCell::new(); + +/// Set the global certificates for SNI callback +pub fn set_global_certificates(certificates: Arc) { + let _ = GLOBAL_CERTIFICATES.set(certificates); +} + +/// Get the global certificates for SNI callback +fn get_global_certificates() -> Option> { + GLOBAL_CERTIFICATES.get().cloned() +} + +#[derive(Clone, Deserialize, Debug)] +pub struct CertificateConfig { + pub cert_path: String, + pub key_path: String, +} + +#[derive(Clone, Debug)] +struct CertificateInfo { + common_names: Vec, + alt_names: Vec, + ssl_context: SslContext, + cert_path: String, + #[allow(dead_code)] // Used during Certificates initialization + key_path: String, +} + +#[derive(Clone, Debug)] +pub struct Certificates { + configs: Vec, + name_map: DashMap, + // Map from certificate name (e.g., "arxignis.dev") to SSL context + cert_name_map: DashMap, + // Map from hostname (e.g., "david-playground3.arxignis.dev") to certificate name (e.g., "arxignis.dev") + upstreams_cert_map: DashMap, + pub default_cert_path: String, + pub default_key_path: String, +} + +// Implement TlsAccept trait for dynamic certificate selection based on SNI +#[async_trait] +impl TlsAccept for Certificates { + async fn certificate_callback(&self, ssl: &mut SslRef) { + if let Some(server_name) = ssl.servername(NameType::HOST_NAME) { + let name_str = server_name.to_string(); + log::info!("TlsAccept::certificate_callback invoked for hostname: {}", name_str); + log::debug!("TlsAccept: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); + log::debug!("TlsAccept: cert_name_map has {} entries", self.cert_name_map.len()); + + // Find the matching SSL context for this hostname + if let Some(ctx) = self.find_ssl_context(&name_str) { + // Log which certificate was found (will be logged in find_ssl_context) + log::info!("TlsAccept: Found matching certificate for hostname: {} (see details above)", name_str); + + // Get the certificate and key from the SSL context + // We need to extract them from the context to use with ssl_use_certificate + // However, SslContext doesn't expose the certificate/key directly + // So we'll use set_ssl_context instead, which should work + match ssl.set_ssl_context(&*ctx) { + Ok(_) => { + log::info!("TlsAccept: Successfully set SSL context for hostname: {}", name_str); + return; + } + Err(e) => { + log::error!("TlsAccept: Failed to set SSL context for hostname {}: {:?}", name_str, e); + // Fall through to use default certificate + } + } + } else { + log::warn!("TlsAccept: No matching certificate found for hostname: {}, using default", name_str); + } + } else { + log::debug!("TlsAccept: No SNI provided, using default certificate"); + } + + // Use default certificate - get it by name from default_cert_path + let default_cert_name = std::path::Path::new(&self.default_cert_path) + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("default"); + + if let Some(default_ctx) = self.cert_name_map.get(default_cert_name) { + let ctx = default_ctx.value(); + log::info!("TlsAccept: Using configured default certificate: {}", default_cert_name); + if let Err(e) = ssl.set_ssl_context(&*ctx) { + log::error!("TlsAccept: Failed to set default SSL context: {:?}", e); + } else { + log::debug!("TlsAccept: Successfully set default certificate"); + } + } else { + // Fallback to first available certificate if default not found + log::warn!("TlsAccept: Default certificate '{}' not found in cert_name_map, 
using first available", default_cert_name); + if let Some(default_ctx) = self.cert_name_map.iter().next() { + let ctx = default_ctx.value(); + if let Err(e) = ssl.set_ssl_context(&*ctx) { + log::error!("TlsAccept: Failed to set fallback SSL context: {:?}", e); + } else { + log::debug!("TlsAccept: Using fallback certificate"); + } + } else { + log::error!("TlsAccept: No certificates available!"); + } + } + } +} + +impl Certificates { + pub fn new(configs: &Vec, _grade: &str, default_certificate: Option<&String>) -> Option { + Self::new_with_sni_callback(configs, _grade, default_certificate, None) + } + + pub fn new_with_sni_callback( + configs: &Vec, + _grade: &str, + default_certificate: Option<&String>, + _certificates_for_callback: Option>, + ) -> Option { + if configs.is_empty() { + warn!("No TLS certificates found, TLS will be disabled until certificates are added"); + return None; + } + + // First, create a temporary Certificates struct to get access to it in the callback + // We'll recreate it properly after loading all certificates + let mut cert_infos = Vec::new(); + let name_map: DashMap = DashMap::new(); + let mut valid_configs = Vec::new(); + + for config in configs { + let cert_info = load_cert_info(&config.cert_path, &config.key_path, _grade); + match cert_info { + Some(cert) => { + for name in &cert.common_names { + name_map.insert(name.clone(), cert.ssl_context.clone()); + } + for name in &cert.alt_names { + name_map.insert(name.clone(), cert.ssl_context.clone()); + } + + cert_infos.push(cert); + valid_configs.push(config.clone()); + } + None => { + warn!("Skipping invalid certificate: cert={}, key={}", &config.cert_path, &config.key_path); + // Continue with other certificates instead of failing + } + } + } + + if cert_infos.is_empty() { + error!("No valid certificates could be loaded from {} certificate configs", configs.len()); + return None; + } + + // Find default certificate: use configured default_certificate if specified, otherwise use first valid certificate + let default_cert = if let Some(default_cert_name) = default_certificate { + // Try to find certificate by name (file stem without extension) + let found = valid_configs.iter().find(|config| { + if let Some(file_name) = std::path::Path::new(&config.cert_path) + .file_stem() + .and_then(|s| s.to_str()) + { + file_name == default_cert_name.as_str() + } else { + false + } + }); + match found { + Some(cert) => { + log::info!("Using configured default certificate: {}", default_cert_name); + cert + } + None => { + log::warn!("Configured default certificate '{}' not found, using first valid certificate", default_cert_name); + &valid_configs[0] + } + } + } else { + // Use first valid certificate as default + &valid_configs[0] + }; + + // Build cert_name_map: map from certificate file name (without extension) to SSL context + let cert_name_map: DashMap = DashMap::new(); + for (idx, config) in valid_configs.iter().enumerate() { + // Extract certificate name from path (e.g., "/path/to/arxignis.dev.crt" -> "arxignis.dev") + // Use file_stem() to get the filename without extension + if let Some(file_name) = std::path::Path::new(&config.cert_path) + .file_stem() + .and_then(|s| s.to_str()) + { + if let Some(cert_info) = cert_infos.get(idx) { + let cert_name = file_name.to_string(); + cert_name_map.insert(cert_name.clone(), cert_info.ssl_context.clone()); + log::debug!("Mapped certificate name '{}' to SSL context (from path: {})", cert_name, config.cert_path); + } + } else { + log::warn!("Failed to extract certificate name 
from path: {}", config.cert_path); + } + } + + log::debug!("Built cert_name_map with {} entries", cert_name_map.len()); + + Some(Self { + name_map: name_map, + cert_name_map: cert_name_map, + upstreams_cert_map: DashMap::new(), + configs: cert_infos, + default_cert_path: default_cert.cert_path.clone(), + default_key_path: default_cert.key_path.clone(), + }) + } + + /// Set upstreams certificate mappings (hostname -> certificate_name) + /// The certificate_name should match the file stem used in cert_name_map + /// (i.e., normalized and sanitized: remove wildcard prefix, replace . with _) + pub fn set_upstreams_cert_map(&self, mappings: DashMap) { + self.upstreams_cert_map.clear(); + for entry in mappings.iter() { + let hostname = entry.key().clone(); + let cert_name = entry.value().clone(); + // Normalize certificate name to match file stem format used in cert_name_map + // Remove wildcard prefix if present, then sanitize (replace . with _) + let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(&cert_name); + let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); + self.upstreams_cert_map.insert(hostname.clone(), sanitized_cert_name.clone()); + log::info!("Mapped hostname '{}' to certificate '{}' (normalized from '{}')", hostname, sanitized_cert_name, cert_name); + } + log::info!("Set upstreams certificate mappings: {} entries", self.upstreams_cert_map.len()); + } + + fn find_ssl_context(&self, server_name: &str) -> Option { + log::debug!("Finding SSL context for server_name: {}", server_name); + log::debug!("upstreams_cert_map entries: {:?}", + self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); + log::debug!("cert_name_map entries: {:?}", + self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + + // First, check if there's an upstreams mapping for this hostname + if let Some(cert_name) = self.upstreams_cert_map.get(server_name) { + let cert_name_str = cert_name.value(); + log::info!("Found upstreams mapping: {} -> {}", server_name, cert_name_str); + if let Some(ctx) = self.cert_name_map.get(cert_name_str) { + log::info!("Using certificate '{}' for hostname '{}' via upstreams mapping", cert_name_str, server_name); + return Some(ctx.clone()); + } else { + // Certificate specified in upstreams.yaml but doesn't exist - use default instead of searching further + log::warn!("Certificate '{}' specified in upstreams config for hostname '{}' not found in cert_name_map. Available certificates: {:?}. 
Will use default certificate (NOT searching for wildcards).", + cert_name_str, server_name, + self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + return None; // Return None to use default certificate - DO NOT continue searching + } + } else { + log::debug!("No upstreams mapping found for hostname: {}, will search for exact/wildcard matches", server_name); + } + + // Then, try exact match in name_map (from certificate CN/SAN) + if let Some(ctx) = self.name_map.get(server_name) { + log::info!("Found certificate via CN/SAN exact match for: {}", server_name); + return Some(ctx.clone()); + } + + // Try wildcard match from certificate CN/SAN before falling back to default + for config in &self.configs { + for name in &config.common_names { + if name.starts_with("*.") && server_name.ends_with(&name[1..]) { + log::info!("Found certificate via CN wildcard match: {} matches {}", server_name, name); + return Some(config.ssl_context.clone()); + } + } + for name in &config.alt_names { + if name.starts_with("*.") && server_name.ends_with(&name[1..]) { + log::info!("Found certificate via SAN wildcard match: {} matches {}", server_name, name); + return Some(config.ssl_context.clone()); + } + } + } + + // Check if default certificate is configured - use it as fallback + let default_cert_name = std::path::Path::new(&self.default_cert_path) + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("default"); + + // If default certificate exists and is configured, use it as fallback + if self.cert_name_map.contains_key(default_cert_name) { + log::info!("No exact or wildcard match found for '{}', will use default certificate '{}'", server_name, default_cert_name); + return None; // Return None to use default certificate + } + + log::warn!("No matching certificate found for hostname: {}, will use default certificate", server_name); + None + } + + pub fn server_name_callback(&self, ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert) -> Result<(), SniError> { + let server_name_opt = ssl_ref.servername(NameType::HOST_NAME); + log::info!("TLS server_name_callback invoked: server_name = {:?}", server_name_opt); + if let Some(name) = server_name_opt { + let name_str = name.to_string(); + log::info!("SNI callback: Looking up certificate for hostname: {}", name_str); + log::debug!("SNI callback: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); + log::debug!("SNI callback: cert_name_map has {} entries", self.cert_name_map.len()); + + match self.find_ssl_context(&name_str) { + Some(ctx) => { + log::info!("SNI callback: Found matching certificate for hostname: {}", name_str); + log::info!("SNI callback: Setting SSL context for hostname: {}", name_str); + ssl_ref.set_ssl_context(&*ctx).map_err(|e| { + log::error!("SNI callback: Failed to set SSL context for hostname {}: {:?}", name_str, e); + SniError::ALERT_FATAL + })?; + log::info!("SNI callback: Successfully set SSL context for hostname: {}", name_str); + } + None => { + log::warn!("SNI callback: No matching certificate found for hostname: {}, using default certificate", name_str); + log::debug!("SNI callback: Available upstreams mappings: {:?}", + self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); + log::debug!("SNI callback: Available certificate names: {:?}", + self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + // Don't set a context - let it use the default + } + } + } else { + log::debug!("SNI callback: No server name (SNI) provided in TLS handshake"); + } + Ok(()) + } + + /// 
Get certificate path for a given hostname + pub fn get_cert_path_for_hostname(&self, hostname: &str) -> Option { + // First try exact match + if self.name_map.contains_key(hostname) { + // Find the certificate info that matches this hostname + for config in &self.configs { + if config.common_names.contains(&hostname.to_string()) || config.alt_names.contains(&hostname.to_string()) { + return Some(config.cert_path.clone()); + } + } + } + + // Try wildcard match + for config in &self.configs { + for name in &config.common_names { + if name.starts_with("*.") && hostname.ends_with(&name[1..]) { + return Some(config.cert_path.clone()); + } + } + for name in &config.alt_names { + if name.starts_with("*.") && hostname.ends_with(&name[1..]) { + return Some(config.cert_path.clone()); + } + } + } + + // Return default certificate path if no match found + Some(self.default_cert_path.clone()) + } +} + +fn load_cert_info(cert_path: &str, key_path: &str, _grade: &str) -> Option { + let mut common_names = HashSet::new(); + let mut alt_names = HashSet::new(); + + let file = File::open(cert_path); + match file { + Err(e) => { + log::error!("Failed to open certificate file: {:?}", e); + return None; + } + Ok(file) => { + let mut reader = BufReader::new(file); + match read_one(&mut reader) { + Err(e) => { + log::error!("Failed to decode PEM from certificate file: {:?}", e); + return None; + } + Ok(leaf) => match leaf { + Some(Item::X509Certificate(cert)) => match X509Certificate::from_der(&cert) { + Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => { + log::error!("Failed to parse certificate: {:?}", e); + return None; + } + Err(_) => { + log::error!("Unknown error while parsing certificate"); + return None; + } + Ok((_, x509)) => { + let subject = x509.subject(); + for attr in subject.iter_common_name() { + if let Ok(cn) = attr.as_str() { + common_names.insert(cn.to_string()); + } + } + + if let Ok(Some(san)) = x509.subject_alternative_name() { + for name in san.value.general_names.iter() { + if let GeneralName::DNSName(dns) = name { + let dns_string = dns.to_string(); + if !common_names.contains(&dns_string) { + alt_names.insert(dns_string); + } + } + } + } + } + }, + _ => { + log::error!("Failed to read certificate"); + return None; + } + }, + } + } + } + + match create_ssl_context(cert_path, key_path) { + Ok(ssl_context) => { + Some(CertificateInfo { + common_names: common_names.into_iter().collect(), + alt_names: alt_names.into_iter().collect(), + ssl_context, + cert_path: cert_path.to_string(), + key_path: key_path.to_string(), + }) + } + Err(e) => { + log::error!("Failed to create SSL context from cert paths '{}' and '{}': {}", cert_path, key_path, e); + None + } + } +} + +fn create_ssl_context(cert_path: &str, key_path: &str) -> Result> { + // Always try to use global certificates for SNI callback + // This ensures that even contexts created without explicit certificates + // will have the SNI callback set if global certificates are available + create_ssl_context_with_sni_callback(cert_path, key_path, None) +} + +fn create_ssl_context_with_sni_callback( + cert_path: &str, + key_path: &str, + certificates: Option>, +) -> Result> { + let mut ctx = SslContext::builder(SslMethod::tls()) + .map_err(|e| format!("Failed to create SSL context builder: {}", e))?; + + ctx.set_certificate_chain_file(cert_path) + .map_err(|e| format!("Failed to set certificate chain file '{}': {}", cert_path, e))?; + + ctx.set_private_key_file(key_path, SslFiletype::PEM) + .map_err(|e| format!("Failed to set private key file 
'{}': {}", key_path, e))?; + + ctx.set_alpn_select_callback(prefer_h2); + + // Set SNI callback - use provided certificates or global certificates + let certs_for_callback = certificates.or_else(get_global_certificates); + if let Some(certs) = certs_for_callback { + let certs_clone = certs.clone(); + ctx.set_servername_callback(move |ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert| -> Result<(), SniError> { + certs_clone.server_name_callback(ssl_ref, _ssl_alert) + }); + log::debug!("Set SNI callback on SSL context for certificate selection"); + } else { + // Certificates may not be loaded yet (e.g., during startup before Redis certificates are fetched) + // This is expected during initialization, so use debug level instead of warn + static WARNED: std::sync::Once = std::sync::Once::new(); + WARNED.call_once(|| { + log::debug!("No certificates available for SNI callback yet - certificates will be loaded asynchronously. Certificate selection by hostname will work once certificates are loaded."); + }); + } + + let built = ctx.build(); + + Ok(built) +} + +#[derive(Debug)] +pub struct CipherSuite { + pub high: &'static str, + pub medium: &'static str, + pub legacy: &'static str, +} +const CIPHERS: CipherSuite = CipherSuite { + high: "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305", + medium: "ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:AES128-GCM-SHA256", + legacy: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH", +}; + +#[derive(Debug)] +pub enum TlsGrade { + HIGH, + MEDIUM, + LEGACY, +} + +impl TlsGrade { + pub fn from_str(s: &str) -> Option { + match s.to_ascii_lowercase().as_str() { + "high" => Some(TlsGrade::HIGH), + "medium" => Some(TlsGrade::MEDIUM), + "unsafe" => Some(TlsGrade::LEGACY), + _ => None, + } + } +} +pub fn prefer_h2<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> { + match select_next_proto("\x02h2\x08http/1.1".as_bytes(), alpn_in) { + Some(p) => Ok(p), + _ => Err(AlpnError::NOACK), + } +} + +// Helper to set ALPN on TlsSettings +pub fn set_alpn_prefer_h2(tls_settings: &mut pingora_core::listeners::tls::TlsSettings) { + use pingora_core::listeners::ALPN; + tls_settings.set_alpn(ALPN::H2H1); +} + +// Helper to create TlsSettings with SNI callback for certificate selection +// This uses TlsSettings::with_callbacks() which allows us to provide a TlsAccept implementation +// that handles dynamic certificate selection based on SNI (Server Name Indication) +pub fn create_tls_settings_with_sni( + cert_path: &str, + key_path: &str, + grade: &str, + certificates: Option>, +) -> Result> { + // Get the certificates - use provided or fall back to global + let certs = certificates + .or_else(get_global_certificates) + .ok_or_else(|| "No certificates available for TLS configuration".to_string())?; + + log::info!("Creating TlsSettings with callbacks for dynamic certificate selection"); + log::info!("Default certificate: {} / {}", cert_path, key_path); + log::info!("Certificate mappings: {} upstreams, {} certificates", + certs.upstreams_cert_map.len(), certs.cert_name_map.len()); + + // Use TlsSettings::with_callbacks() instead of TlsSettings::intermediate() + // This allows us to provide our Certificates struct which implements TlsAccept + // The certificate_callback method will be called during TLS handshake to select + // the appropriate certificate based on the SNI hostname + // + // Note: with_callbacks expects a Box + // We clone 
the Certificates struct to create a new instance for the callback
+    let tls_accept: Box<dyn TlsAccept + Send + Sync> = Box::new((*certs).clone());
+    let mut tls_settings = TlsSettings::with_callbacks(tls_accept)
+        .map_err(|e| format!("Failed to create TlsSettings with callbacks: {}", e))?;
+
+    // Configure TLS grade and ALPN
+    set_tsl_grade(&mut tls_settings, grade);
+    set_alpn_prefer_h2(&mut tls_settings);
+
+    log::info!("Successfully created TlsSettings with SNI-based certificate selection");
+    log::info!("Certificate selection will work based on hostname from SNI");
+
+    Ok(tls_settings)
+}
+
+pub fn set_tsl_grade(tls_settings: &mut TlsSettings, grade: &str) {
+    let config_grade = TlsGrade::from_str(grade);
+    match config_grade {
+        Some(TlsGrade::HIGH) => {
+            let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1_2));
+            // let _ = tls_settings.set_max_proto_version(Some(SslVersion::TLS1_3));
+            let _ = tls_settings.set_cipher_list(CIPHERS.high);
+            let _ = tls_settings.set_ciphersuites(CIPHERS.high);
+            info!("TLS grade: => HIGH");
+        }
+        Some(TlsGrade::MEDIUM) => {
+            let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1));
+            let _ = tls_settings.set_cipher_list(CIPHERS.medium);
+            let _ = tls_settings.set_ciphersuites(CIPHERS.medium);
+            info!("TLS grade: => MEDIUM");
+        }
+        Some(TlsGrade::LEGACY) => {
+            let _ = tls_settings.set_min_proto_version(Some(SslVersion::SSL3));
+            let _ = tls_settings.set_cipher_list(CIPHERS.legacy);
+            let _ = tls_settings.set_ciphersuites(CIPHERS.legacy);
+            warn!("TLS grade: => UNSAFE");
+        }
+        None => {
+            // Defaults to MEDIUM
+            let _ = tls_settings.set_min_proto_version(Some(SslVersion::TLS1));
+            let _ = tls_settings.set_cipher_list(CIPHERS.medium);
+            let _ = tls_settings.set_ciphersuites(CIPHERS.medium);
+            warn!("TLS grade is not detected, defaulting to MEDIUM");
+        }
+    }
+}
+
+/// Extract server certificate information for access logging
+pub fn extract_cert_info(cert_path: &str) -> Option<crate::access_log::ServerCertInfo> {
+    use sha2::{Digest, Sha256};
+
+    let file = File::open(cert_path).ok()?;
+    let mut reader = BufReader::new(file);
+
+    // Read the first certificate from the PEM file
+    let item = read_one(&mut reader).ok()??;
+
+    let cert_der = match item {
+        Item::X509Certificate(der) => der,
+        _ => return None,
+    };
+
+    // Parse the X.509 certificate
+    let (_, cert) = X509Certificate::from_der(&cert_der).ok()?;
+
+    // Extract issuer
+    let issuer = cert.issuer().to_string();
+
+    // Extract subject
+    let subject = cert.subject().to_string();
+
+    // Extract validity dates (as ISO 8601 format)
+    let not_before = cert.validity().not_before.to_datetime().to_string();
+    let not_after = cert.validity().not_after.to_datetime().to_string();
+
+    // Calculate SHA256 fingerprint
+    let mut hasher = Sha256::new();
+    hasher.update(&cert_der);
+    let fingerprint_sha256 = format!("{:x}", hasher.finalize());
+
+    Some(crate::access_log::ServerCertInfo {
+        issuer,
+        subject,
+        not_before,
+        not_after,
+        fingerprint_sha256,
+    })
+}
+
diff --git a/src/utils/tls_client_hello.rs b/src/utils/tls_client_hello.rs
index b6b2366..e7bd39e 100644
--- a/src/utils/tls_client_hello.rs
+++ b/src/utils/tls_client_hello.rs
@@ -1,224 +1,224 @@
-use pingora_core::protocols::ClientHelloWrapper;
-use crate::utils::tls_fingerprint::Fingerprint;
-use log::{debug, warn};
-use std::sync::Arc;
-use std::collections::HashMap;
-use std::sync::Mutex;
-use std::net::SocketAddr;
-use std::sync::OnceLock;
-use std::time::{SystemTime, UNIX_EPOCH};
-
-/// TLS fingerprint entry with timestamp for fallback matching
-#[derive(Clone)]
-pub struct 
FingerprintEntry { - pub fingerprint: Arc, - pub stored_at: SystemTime, -} - -/// Global storage for TLS fingerprints keyed by connection peer address -/// This is a temporary storage until the fingerprint can be moved to session context -static TLS_FINGERPRINTS: OnceLock>> = OnceLock::new(); - -fn get_fingerprint_map() -> &'static Mutex> { - TLS_FINGERPRINTS.get_or_init(|| Mutex::new(HashMap::new())) -} - -/// Public function to access the fingerprint map -/// This is used by tls_acceptor_wrapper to store fingerprints -pub fn get_fingerprint_map_public() -> &'static Mutex> { - get_fingerprint_map() -} - -/// Generate JA4 fingerprint from ClientHello raw bytes -/// This is called after ClientHello is extracted by ClientHelloWrapper -pub fn generate_fingerprint_from_client_hello( - hello: &pingora_core::protocols::tls::client_hello::ClientHello, - peer_addr: Option, -) -> Option> { - let peer_addr_str = peer_addr.as_ref() - .and_then(|a| a.as_inet()) - .map(|inet| format!("{}:{}", inet.ip(), inet.port())) - .unwrap_or_else(|| "unknown".to_string()); - - debug!("Generating fingerprint from ClientHello: Peer: {}, SNI={:?}, ALPN={:?}, raw_len={}", - peer_addr_str, hello.sni, hello.alpn, hello.raw.len()); - - // Generate JA4 fingerprint from raw ClientHello bytes - if let Some(mut fingerprint) = crate::utils::tls_fingerprint::fingerprint_client_hello(&hello.raw) { - // Always prefer SNI and ALPN from Pingora's parsed ClientHello if available - // Pingora's parsing is more reliable than raw bytes parsing - if hello.sni.is_some() { - fingerprint.sni = hello.sni.clone(); - } - - // ALPN: Pingora returns Vec, use first one if available - if !hello.alpn.is_empty() { - fingerprint.alpn = hello.alpn.first().cloned(); - } - - let fingerprint_arc: Arc = Arc::new(fingerprint); - - // Store fingerprint temporarily if we have peer address - // Convert pingora SocketAddr to std::net::SocketAddr for storage - if let Some(ref addr) = peer_addr { - if let Some(inet) = addr.as_inet() { - let std_addr = SocketAddr::new(inet.ip().into(), inet.port()); - let key = format!("{}", std_addr); - if let Ok(mut map) = get_fingerprint_map().lock() { - let stored_at = SystemTime::now(); - let entry = FingerprintEntry { - fingerprint: fingerprint_arc.clone(), - stored_at, - }; - map.insert(key, entry); - debug!("Stored TLS fingerprint for {} at {:?}", std_addr, stored_at); - } - } - } - - // Log fingerprint details at info level - debug!( - "TLS Fingerprint extracted - Peer: {}, JA4: {}, JA4_Raw: {}, JA4_Unsorted: {}, JA4_Raw_Unsorted: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}", - peer_addr_str, - fingerprint_arc.ja4, - fingerprint_arc.ja4_raw, - fingerprint_arc.ja4_unsorted, - fingerprint_arc.ja4_raw_unsorted, - fingerprint_arc.tls_version, - fingerprint_arc.cipher_suite, - fingerprint_arc.sni, - fingerprint_arc.alpn - ); - - debug!("Generated JA4 fingerprint: {}", fingerprint_arc.ja4); - return Some(fingerprint_arc); - } - - debug!("Failed to generate fingerprint from ClientHello: Peer: {}, raw_len={}", peer_addr_str, hello.raw.len()); - None -} - -/// Extract ClientHello from a stream and generate JA4 fingerprint -/// Returns the fingerprint if extraction was successful -/// The stream should be wrapped with ClientHelloWrapper before TLS handshake -#[cfg(unix)] -pub fn extract_and_fingerprint( - stream: S, - peer_addr: Option, -) -> Option> { - let mut wrapper = ClientHelloWrapper::new(stream); - - match wrapper.extract_client_hello() { - Ok(Some(hello)) => { - // Convert std::net::SocketAddr to pingora 
SocketAddr - use pingora_core::protocols::l4::socket::SocketAddr as PingoraAddr; - let pingora_addr = peer_addr.map(|addr| PingoraAddr::Inet(addr)); - generate_fingerprint_from_client_hello(&hello, pingora_addr) - } - Ok(None) => { - debug!("No ClientHello detected in stream"); - None - } - Err(e) => { - debug!("Failed to extract ClientHello: {:?}", e); - None - } - } -} - -/// Get stored TLS fingerprint for a peer address -pub fn get_fingerprint(peer_addr: &SocketAddr) -> Option> { - let key = format!("{}", peer_addr); - if let Ok(map) = get_fingerprint_map().lock() { - map.get(&key).map(|entry| entry.fingerprint.clone()) - } else { - None - } -} - -/// Get stored TLS fingerprint with fallback strategies for PROXY protocol -/// This tries multiple lookup strategies to handle cases where PROXY protocol -/// might cause address mismatches between storage and retrieval -pub fn get_fingerprint_with_fallback(peer_addr: &SocketAddr) -> Option> { - // First try the exact address match - if let Some(fp) = get_fingerprint(peer_addr) { - debug!("Found TLS fingerprint with exact address match: {}", peer_addr); - return Some(fp); - } - - // If not found, try to find fingerprints with matching IP (in case port differs) - // This helps when PROXY protocol causes port mismatches or when ClientHello - // callback receives a different address than session.client_addr() - if let Ok(map) = get_fingerprint_map().lock() { - let peer_ip = peer_addr.ip(); - let mut matching_entries: Vec<(SocketAddr, FingerprintEntry)> = Vec::new(); - - for (key, entry) in map.iter() { - if let Ok(addr) = key.parse::() { - if addr.ip() == peer_ip { - matching_entries.push((addr, entry.clone())); - } - } - } - - match matching_entries.len() { - 0 => { - debug!("No TLS fingerprint found for IP {} (exact match failed, no IP matches)", peer_ip); - } - 1 => { - let (matched_addr, entry) = &matching_entries[0]; - debug!("Found TLS fingerprint with matching IP but different port: {} -> {} (fallback lookup)", peer_addr, matched_addr); - return Some(entry.fingerprint.clone()); - } - _ => { - // Multiple matches - use the most recent one (most likely to be the correct connection) - // This handles cases where PROXY protocol causes address mismatches - let (matched_addr, entry) = matching_entries.iter() - .max_by_key(|(_, e)| e.stored_at) - .unwrap(); - - warn!("Multiple TLS fingerprints found for IP {} ({} matches), using most recent from {} (stored at {:?})", - peer_ip, matching_entries.len(), matched_addr, entry.stored_at); - return Some(entry.fingerprint.clone()); - } - } - } - - None -} - -/// Remove stored TLS fingerprint for a peer address -pub fn remove_fingerprint(peer_addr: &SocketAddr) { - let key = format!("{}", peer_addr); - if let Ok(mut map) = get_fingerprint_map().lock() { - map.remove(&key); - } -} - -/// Clean up old fingerprints (older than 5 minutes) to prevent memory leaks -/// This should be called periodically -pub fn cleanup_old_fingerprints() { - let cutoff = SystemTime::now().checked_sub(std::time::Duration::from_secs(300)) - .unwrap_or(UNIX_EPOCH); - - if let Ok(mut map) = get_fingerprint_map().lock() { - let initial_len = map.len(); - map.retain(|_, entry| entry.stored_at > cutoff); - let removed = initial_len - map.len(); - if removed > 0 { - debug!("Cleaned up {} old TLS fingerprints (kept {} active)", removed, map.len()); - } - } -} - -#[cfg(not(unix))] -pub fn extract_and_fingerprint( - _stream: S, - _peer_addr: Option, -) -> Option> { - // ClientHello extraction is only supported on Unix - None -} - - - 
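
// A minimal, self-contained sketch of the IP-fallback lookup strategy implemented
// in this module, reduced to plain std types so it can be tested in isolation.
// `Entry` is a simplified stand-in for FingerprintEntry, keyed by "ip:port"
// strings exactly like the real map.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Instant;

struct Entry {
    stored_at: Instant,
}

// Exact "ip:port" match first; otherwise the most recent entry with the same IP.
fn lookup<'a>(map: &'a HashMap<String, Entry>, peer: &SocketAddr) -> Option<&'a Entry> {
    if let Some(e) = map.get(&peer.to_string()) {
        return Some(e);
    }
    map.iter()
        .filter_map(|(k, e)| k.parse::<SocketAddr>().ok().map(|a| (a, e)))
        .filter(|(a, _)| a.ip() == peer.ip())
        .max_by_key(|(_, e)| e.stored_at)
        .map(|(_, e)| e)
}

fn main() {
    let mut map = HashMap::new();
    map.insert("10.0.0.1:5000".to_string(), Entry { stored_at: Instant::now() });
    let peer: SocketAddr = "10.0.0.1:6000".parse().unwrap();
    // Exact match fails (ports differ) but the IP fallback finds the entry,
    // mirroring the PROXY-protocol mismatch case handled above.
    assert!(lookup(&map, &peer).is_some());
}
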
+use pingora_core::protocols::ClientHelloWrapper;
+use crate::utils::tls_fingerprint::Fingerprint;
+use log::{debug, warn};
+use std::sync::Arc;
+use std::collections::HashMap;
+use std::sync::Mutex;
+use std::net::SocketAddr;
+use std::sync::OnceLock;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+/// TLS fingerprint entry with timestamp for fallback matching
+#[derive(Clone)]
+pub struct FingerprintEntry {
+    pub fingerprint: Arc<Fingerprint>,
+    pub stored_at: SystemTime,
+}
+
+/// Global storage for TLS fingerprints keyed by connection peer address
+/// This is a temporary storage until the fingerprint can be moved to session context
+static TLS_FINGERPRINTS: OnceLock<Mutex<HashMap<String, FingerprintEntry>>> = OnceLock::new();
+
+fn get_fingerprint_map() -> &'static Mutex<HashMap<String, FingerprintEntry>> {
+    TLS_FINGERPRINTS.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+/// Public function to access the fingerprint map
+/// This is used by tls_acceptor_wrapper to store fingerprints
+pub fn get_fingerprint_map_public() -> &'static Mutex<HashMap<String, FingerprintEntry>> {
+    get_fingerprint_map()
+}
+
+/// Generate JA4 fingerprint from ClientHello raw bytes
+/// This is called after ClientHello is extracted by ClientHelloWrapper
+pub fn generate_fingerprint_from_client_hello(
+    hello: &pingora_core::protocols::tls::client_hello::ClientHello,
+    peer_addr: Option<pingora_core::protocols::l4::socket::SocketAddr>,
+) -> Option<Arc<Fingerprint>> {
+    let peer_addr_str = peer_addr.as_ref()
+        .and_then(|a| a.as_inet())
+        .map(|inet| format!("{}:{}", inet.ip(), inet.port()))
+        .unwrap_or_else(|| "unknown".to_string());
+
+    debug!("Generating fingerprint from ClientHello: Peer: {}, SNI={:?}, ALPN={:?}, raw_len={}",
+        peer_addr_str, hello.sni, hello.alpn, hello.raw.len());
+
+    // Generate JA4 fingerprint from raw ClientHello bytes
+    if let Some(mut fingerprint) = crate::utils::tls_fingerprint::fingerprint_client_hello(&hello.raw) {
+        // Always prefer SNI and ALPN from Pingora's parsed ClientHello if available
+        // Pingora's parsing is more reliable than raw bytes parsing
+        if hello.sni.is_some() {
+            fingerprint.sni = hello.sni.clone();
+        }
+
+        // ALPN: Pingora returns a list of protocols, use the first one if available
+        if !hello.alpn.is_empty() {
+            fingerprint.alpn = hello.alpn.first().cloned();
+        }
+
+        let fingerprint_arc: Arc<Fingerprint> = Arc::new(fingerprint);
+
+        // Store fingerprint temporarily if we have peer address
+        // Convert pingora SocketAddr to std::net::SocketAddr for storage
+        if let Some(ref addr) = peer_addr {
+            if let Some(inet) = addr.as_inet() {
+                let std_addr = SocketAddr::new(inet.ip().into(), inet.port());
+                let key = format!("{}", std_addr);
+                if let Ok(mut map) = get_fingerprint_map().lock() {
+                    let stored_at = SystemTime::now();
+                    let entry = FingerprintEntry {
+                        fingerprint: fingerprint_arc.clone(),
+                        stored_at,
+                    };
+                    map.insert(key, entry);
+                    debug!("Stored TLS fingerprint for {} at {:?}", std_addr, stored_at);
+                }
+            }
+        }
+
+        // Log fingerprint details at debug level
+        debug!(
+            "TLS Fingerprint extracted - Peer: {}, JA4: {}, JA4_Raw: {}, JA4_Unsorted: {}, JA4_Raw_Unsorted: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}",
+            peer_addr_str,
+            fingerprint_arc.ja4,
+            fingerprint_arc.ja4_raw,
+            fingerprint_arc.ja4_unsorted,
+            fingerprint_arc.ja4_raw_unsorted,
+            fingerprint_arc.tls_version,
+            fingerprint_arc.cipher_suite,
+            fingerprint_arc.sni,
+            fingerprint_arc.alpn
+        );
+
+        debug!("Generated JA4 fingerprint: {}", fingerprint_arc.ja4);
+        return Some(fingerprint_arc);
+    }
+
+    debug!("Failed to generate fingerprint from ClientHello: Peer: {}, raw_len={}", peer_addr_str, hello.raw.len());
+    None
+}
+
+/// Extract ClientHello from a stream and generate JA4 fingerprint
+///
Returns the fingerprint if extraction was successful +/// The stream should be wrapped with ClientHelloWrapper before TLS handshake +#[cfg(unix)] +pub fn extract_and_fingerprint( + stream: S, + peer_addr: Option, +) -> Option> { + let mut wrapper = ClientHelloWrapper::new(stream); + + match wrapper.extract_client_hello() { + Ok(Some(hello)) => { + // Convert std::net::SocketAddr to pingora SocketAddr + use pingora_core::protocols::l4::socket::SocketAddr as PingoraAddr; + let pingora_addr = peer_addr.map(|addr| PingoraAddr::Inet(addr)); + generate_fingerprint_from_client_hello(&hello, pingora_addr) + } + Ok(None) => { + debug!("No ClientHello detected in stream"); + None + } + Err(e) => { + debug!("Failed to extract ClientHello: {:?}", e); + None + } + } +} + +/// Get stored TLS fingerprint for a peer address +pub fn get_fingerprint(peer_addr: &SocketAddr) -> Option> { + let key = format!("{}", peer_addr); + if let Ok(map) = get_fingerprint_map().lock() { + map.get(&key).map(|entry| entry.fingerprint.clone()) + } else { + None + } +} + +/// Get stored TLS fingerprint with fallback strategies for PROXY protocol +/// This tries multiple lookup strategies to handle cases where PROXY protocol +/// might cause address mismatches between storage and retrieval +pub fn get_fingerprint_with_fallback(peer_addr: &SocketAddr) -> Option> { + // First try the exact address match + if let Some(fp) = get_fingerprint(peer_addr) { + debug!("Found TLS fingerprint with exact address match: {}", peer_addr); + return Some(fp); + } + + // If not found, try to find fingerprints with matching IP (in case port differs) + // This helps when PROXY protocol causes port mismatches or when ClientHello + // callback receives a different address than session.client_addr() + if let Ok(map) = get_fingerprint_map().lock() { + let peer_ip = peer_addr.ip(); + let mut matching_entries: Vec<(SocketAddr, FingerprintEntry)> = Vec::new(); + + for (key, entry) in map.iter() { + if let Ok(addr) = key.parse::() { + if addr.ip() == peer_ip { + matching_entries.push((addr, entry.clone())); + } + } + } + + match matching_entries.len() { + 0 => { + debug!("No TLS fingerprint found for IP {} (exact match failed, no IP matches)", peer_ip); + } + 1 => { + let (matched_addr, entry) = &matching_entries[0]; + debug!("Found TLS fingerprint with matching IP but different port: {} -> {} (fallback lookup)", peer_addr, matched_addr); + return Some(entry.fingerprint.clone()); + } + _ => { + // Multiple matches - use the most recent one (most likely to be the correct connection) + // This handles cases where PROXY protocol causes address mismatches + let (matched_addr, entry) = matching_entries.iter() + .max_by_key(|(_, e)| e.stored_at) + .unwrap(); + + warn!("Multiple TLS fingerprints found for IP {} ({} matches), using most recent from {} (stored at {:?})", + peer_ip, matching_entries.len(), matched_addr, entry.stored_at); + return Some(entry.fingerprint.clone()); + } + } + } + + None +} + +/// Remove stored TLS fingerprint for a peer address +pub fn remove_fingerprint(peer_addr: &SocketAddr) { + let key = format!("{}", peer_addr); + if let Ok(mut map) = get_fingerprint_map().lock() { + map.remove(&key); + } +} + +/// Clean up old fingerprints (older than 5 minutes) to prevent memory leaks +/// This should be called periodically +pub fn cleanup_old_fingerprints() { + let cutoff = SystemTime::now().checked_sub(std::time::Duration::from_secs(300)) + .unwrap_or(UNIX_EPOCH); + + if let Ok(mut map) = get_fingerprint_map().lock() { + let initial_len = 
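+        // Capture the map size before retain() below so the number of expired entries can be reported.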
map.len(); + map.retain(|_, entry| entry.stored_at > cutoff); + let removed = initial_len - map.len(); + if removed > 0 { + debug!("Cleaned up {} old TLS fingerprints (kept {} active)", removed, map.len()); + } + } +} + +#[cfg(not(unix))] +pub fn extract_and_fingerprint( + _stream: S, + _peer_addr: Option, +) -> Option> { + // ClientHello extraction is only supported on Unix + None +} + + + diff --git a/src/utils/tls_fingerprint.rs b/src/utils/tls_fingerprint.rs index a8f159e..25a7c49 100644 --- a/src/utils/tls_fingerprint.rs +++ b/src/utils/tls_fingerprint.rs @@ -1,336 +1,336 @@ -use sha2::{Digest, Sha256}; -use tls_parser::{ - TlsClientHelloContents, TlsExtension, TlsExtensionType, TlsMessage, TlsMessageHandshake, - parse_tls_extensions, parse_tls_plaintext, -}; - -/// GREASE values as defined in RFC 8701. -pub const TLS_GREASE_VALUES: [u16; 16] = [ - 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, 0xbaba, - 0xcaca, 0xdada, 0xeaea, 0xfafa, -]; - -/// High level JA4 fingerprint summary derived from a TLS ClientHello. -#[derive(Debug, Clone)] -pub struct Fingerprint { - pub ja4: String, - pub ja4_raw: String, - pub ja4_unsorted: String, - pub ja4_raw_unsorted: String, - pub tls_version: String, - pub cipher_suite: Option, - pub sni: Option, - pub alpn: Option, -} - -/// Attempt to parse a TLS ClientHello from the supplied bytes and, if successful, -/// return the corresponding JA4 fingerprints. -pub fn fingerprint_client_hello(data: &[u8]) -> Option { - let (_, record) = parse_tls_plaintext(data).ok()?; - for message in record.msg.iter() { - if let TlsMessage::Handshake(TlsMessageHandshake::ClientHello(client_hello)) = message { - let signature = extract_tls_signature_from_client_hello(client_hello).ok()?; - let sorted = signature.generate_ja4_with_order(false); - let unsorted = signature.generate_ja4_with_order(true); - return Some(Fingerprint { - ja4: sorted.full.value().to_string(), - ja4_raw: sorted.raw.value().to_string(), - ja4_unsorted: unsorted.full.value().to_string(), - ja4_raw_unsorted: unsorted.raw.value().to_string(), - tls_version: signature.version.to_string(), - cipher_suite: signature.preferred_cipher_suite.map(cipher_suite_to_string), - sni: signature.sni.clone(), - alpn: signature.alpn.clone(), - }); - } - } - None -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum TlsVersion { - V1_3, - V1_2, - V1_1, - V1_0, - Ssl3_0, - #[allow(dead_code)] - Ssl2_0, - Unknown(u16), -} - -impl std::fmt::Display for TlsVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TlsVersion::V1_3 => write!(f, "13"), - TlsVersion::V1_2 => write!(f, "12"), - TlsVersion::V1_1 => write!(f, "11"), - TlsVersion::V1_0 => write!(f, "10"), - TlsVersion::Ssl3_0 => write!(f, "s3"), - TlsVersion::Ssl2_0 => write!(f, "s2"), - TlsVersion::Unknown(_) => write!(f, "00"), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -enum Ja4Fingerprint { - Sorted(String), - Unsorted(String), -} - -impl Ja4Fingerprint { - fn value(&self) -> &str { - match self { - Ja4Fingerprint::Sorted(v) => v, - Ja4Fingerprint::Unsorted(v) => v, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -enum Ja4RawFingerprint { - Sorted(String), - Unsorted(String), -} - -impl Ja4RawFingerprint { - fn value(&self) -> &str { - match self { - Ja4RawFingerprint::Sorted(v) => v, - Ja4RawFingerprint::Unsorted(v) => v, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -struct Ja4Payload { - full: Ja4Fingerprint, - raw: Ja4RawFingerprint, -} - -#[derive(Debug, 
Clone, PartialEq)] -struct Signature { - version: TlsVersion, - cipher_suites: Vec, - preferred_cipher_suite: Option, - extensions: Vec, - elliptic_curves: Vec, - elliptic_curve_point_formats: Vec, - signature_algorithms: Vec, - sni: Option, - alpn: Option, -} - -impl Signature { - fn generate_ja4_with_order(&self, original_order: bool) -> Ja4Payload { - let filtered_ciphers = filter_grease_values(&self.cipher_suites); - let filtered_extensions = filter_grease_values(&self.extensions); - let filtered_sig_algs = filter_grease_values(&self.signature_algorithms); - - let protocol = "t"; - let tls_version_str = format!("{}", self.version); - let sni_indicator = if self.sni.is_some() { "d" } else { "i" }; - let cipher_count = format!("{:02}", self.cipher_suites.len().min(99)); - let extension_count = format!("{:02}", self.extensions.len().min(99)); - let (alpn_first, alpn_last) = match &self.alpn { - Some(alpn) => first_last_alpn(alpn), - None => ('0', '0'), - }; - let ja4_a = format!( - "{protocol}{tls_version_str}{sni_indicator}{cipher_count}{extension_count}{alpn_first}{alpn_last}" - ); - - let mut ciphers_for_b = filtered_ciphers; - if !original_order { - ciphers_for_b.sort_unstable(); - } - let ja4_b_raw = ciphers_for_b - .iter() - .map(|c| format!("{c:04x}")) - .collect::>() - .join(","); - - let mut extensions_for_c = filtered_extensions; - if !original_order { - extensions_for_c.retain(|&ext| ext != 0x0000 && ext != 0x0010); - extensions_for_c.sort_unstable(); - } - let extensions_str = extensions_for_c - .iter() - .map(|e| format!("{e:04x}")) - .collect::>() - .join(","); - - let sig_algs_str = filtered_sig_algs - .iter() - .map(|s| format!("{s:04x}")) - .collect::>() - .join(","); - - let ja4_c_raw = if sig_algs_str.is_empty() { - extensions_str - } else if extensions_str.is_empty() { - sig_algs_str - } else { - format!("{extensions_str}_{sig_algs_str}") - }; - - let ja4_b_hash = hash12(&ja4_b_raw); - let ja4_c_hash = hash12(&ja4_c_raw); - - let ja4_hashed = format!("{ja4_a}_{ja4_b_hash}_{ja4_c_hash}"); - let ja4_raw_full = format!("{ja4_a}_{ja4_b_raw}_{ja4_c_raw}"); - - let full = if original_order { - Ja4Fingerprint::Unsorted(ja4_hashed) - } else { - Ja4Fingerprint::Sorted(ja4_hashed) - }; - let raw = if original_order { - Ja4RawFingerprint::Unsorted(ja4_raw_full) - } else { - Ja4RawFingerprint::Sorted(ja4_raw_full) - }; - - Ja4Payload { full, raw } - } -} - -fn extract_tls_signature_from_client_hello( - client_hello: &TlsClientHelloContents, -) -> Result { - let cipher_suites: Vec = client_hello.ciphers.iter().map(|c| c.0).collect(); - // Filter out GREASE values before selecting preferred cipher suite - let filtered_cipher_suites = filter_grease_values(&cipher_suites); - let preferred_cipher_suite = filtered_cipher_suites.first().copied(); - - let mut extensions = Vec::new(); - let mut sni = None; - let mut alpn = None; - let mut signature_algorithms = Vec::new(); - let mut elliptic_curves = Vec::new(); - let mut elliptic_curve_point_formats = Vec::new(); - - if let Some(ext_data) = &client_hello.ext - && let Ok((_remaining, parsed_extensions)) = parse_tls_extensions(ext_data) - { - for extension in &parsed_extensions { - let ext_type: u16 = TlsExtensionType::from(extension).into(); - if !is_grease_value(ext_type) { - extensions.push(ext_type); - } - match extension { - TlsExtension::SNI(sni_list) => { - if let Some((_, hostname)) = sni_list.first() { - sni = String::from_utf8(hostname.to_vec()).ok(); - } - } - TlsExtension::ALPN(alpn_list) => { - if let Some(protocol) = 
alpn_list.first() { - alpn = String::from_utf8(protocol.to_vec()).ok(); - } - } - TlsExtension::SignatureAlgorithms(sig_algs) => { - signature_algorithms = sig_algs.clone(); - } - TlsExtension::EllipticCurves(curves) => { - elliptic_curves = curves.iter().map(|c| c.0).collect(); - } - TlsExtension::EcPointFormats(formats) => { - elliptic_curve_point_formats = formats.to_vec(); - } - _ => {} - } - } - } - - let version = determine_tls_version(&client_hello.version, &extensions); - - Ok(Signature { - version, - cipher_suites, - preferred_cipher_suite, - extensions, - elliptic_curves, - elliptic_curve_point_formats, - signature_algorithms, - sni, - alpn, - }) -} - -fn determine_tls_version( - legacy_version: &tls_parser::TlsVersion, - extensions: &[u16], -) -> TlsVersion { - if extensions.contains(&TlsExtensionType::SupportedVersions.into()) { - return TlsVersion::V1_3; - } - - match *legacy_version { - tls_parser::TlsVersion::Tls13 => TlsVersion::V1_3, - tls_parser::TlsVersion::Tls12 => TlsVersion::V1_2, - tls_parser::TlsVersion::Tls11 => TlsVersion::V1_1, - tls_parser::TlsVersion::Tls10 => TlsVersion::V1_0, - tls_parser::TlsVersion::Ssl30 => TlsVersion::Ssl3_0, - other => TlsVersion::Unknown(other.into()), - } -} - -fn is_grease_value(value: u16) -> bool { - TLS_GREASE_VALUES.contains(&value) -} - -fn filter_grease_values(values: &[u16]) -> Vec { - values - .iter() - .copied() - .filter(|v| !is_grease_value(*v)) - .collect() -} - -fn first_last_alpn(s: &str) -> (char, char) { - let replace_nonascii_with_9 = |c: char| if c.is_ascii() { c } else { '9' }; - let mut chars = s.chars(); - let first = chars.next().map(replace_nonascii_with_9).unwrap_or('0'); - let last = if s.len() == 1 { - '0' - } else { - chars - .next_back() - .map(replace_nonascii_with_9) - .unwrap_or('0') - }; - (first, last) -} - -fn hash12(input: &str) -> String { - let digest = Sha256::digest(input.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() -} - -fn cipher_suite_to_string(cipher_suite: u16) -> String { - match cipher_suite { - 0x1301 => "TLS_AES_128_GCM_SHA256".to_string(), - 0x1302 => "TLS_AES_256_GCM_SHA384".to_string(), - 0x1303 => "TLS_CHACHA20_POLY1305_SHA256".to_string(), - 0x1304 => "TLS_AES_128_CCM_SHA256".to_string(), - 0x1305 => "TLS_AES_128_CCM_8_SHA256".to_string(), - 0xc02f => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256".to_string(), - 0xc030 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), - 0xc02b => "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256".to_string(), - 0xc02c => "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384".to_string(), - 0xcca8 => "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256".to_string(), - 0xcca9 => "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256".to_string(), - 0x009e => "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256".to_string(), - 0x009f => "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), - 0x0035 => "TLS_RSA_WITH_AES_128_GCM_SHA256".to_string(), - 0x0036 => "TLS_RSA_WITH_AES_256_GCM_SHA384".to_string(), - 0x002f => "TLS_RSA_WITH_AES_128_CBC_SHA".to_string(), - 0x003c => "TLS_RSA_WITH_AES_128_CBC_SHA256".to_string(), - 0x003d => "TLS_RSA_WITH_AES_256_CBC_SHA256".to_string(), - _ => format!("UNKNOWN_CIPHER_{:04x}", cipher_suite), - } -} +use sha2::{Digest, Sha256}; +use tls_parser::{ + TlsClientHelloContents, TlsExtension, TlsExtensionType, TlsMessage, TlsMessageHandshake, + parse_tls_extensions, parse_tls_plaintext, +}; + +/// GREASE values as defined in RFC 8701. 
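+/// Clients inject these reserved values at random; they carry no negotiation
+/// meaning, so they are stripped before the JA4 strings below are assembled.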
+pub const TLS_GREASE_VALUES: [u16; 16] = [ + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, 0xbaba, + 0xcaca, 0xdada, 0xeaea, 0xfafa, +]; + +/// High level JA4 fingerprint summary derived from a TLS ClientHello. +#[derive(Debug, Clone)] +pub struct Fingerprint { + pub ja4: String, + pub ja4_raw: String, + pub ja4_unsorted: String, + pub ja4_raw_unsorted: String, + pub tls_version: String, + pub cipher_suite: Option, + pub sni: Option, + pub alpn: Option, +} + +/// Attempt to parse a TLS ClientHello from the supplied bytes and, if successful, +/// return the corresponding JA4 fingerprints. +pub fn fingerprint_client_hello(data: &[u8]) -> Option { + let (_, record) = parse_tls_plaintext(data).ok()?; + for message in record.msg.iter() { + if let TlsMessage::Handshake(TlsMessageHandshake::ClientHello(client_hello)) = message { + let signature = extract_tls_signature_from_client_hello(client_hello).ok()?; + let sorted = signature.generate_ja4_with_order(false); + let unsorted = signature.generate_ja4_with_order(true); + return Some(Fingerprint { + ja4: sorted.full.value().to_string(), + ja4_raw: sorted.raw.value().to_string(), + ja4_unsorted: unsorted.full.value().to_string(), + ja4_raw_unsorted: unsorted.raw.value().to_string(), + tls_version: signature.version.to_string(), + cipher_suite: signature.preferred_cipher_suite.map(cipher_suite_to_string), + sni: signature.sni.clone(), + alpn: signature.alpn.clone(), + }); + } + } + None +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum TlsVersion { + V1_3, + V1_2, + V1_1, + V1_0, + Ssl3_0, + #[allow(dead_code)] + Ssl2_0, + Unknown(u16), +} + +impl std::fmt::Display for TlsVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TlsVersion::V1_3 => write!(f, "13"), + TlsVersion::V1_2 => write!(f, "12"), + TlsVersion::V1_1 => write!(f, "11"), + TlsVersion::V1_0 => write!(f, "10"), + TlsVersion::Ssl3_0 => write!(f, "s3"), + TlsVersion::Ssl2_0 => write!(f, "s2"), + TlsVersion::Unknown(_) => write!(f, "00"), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum Ja4Fingerprint { + Sorted(String), + Unsorted(String), +} + +impl Ja4Fingerprint { + fn value(&self) -> &str { + match self { + Ja4Fingerprint::Sorted(v) => v, + Ja4Fingerprint::Unsorted(v) => v, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum Ja4RawFingerprint { + Sorted(String), + Unsorted(String), +} + +impl Ja4RawFingerprint { + fn value(&self) -> &str { + match self { + Ja4RawFingerprint::Sorted(v) => v, + Ja4RawFingerprint::Unsorted(v) => v, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +struct Ja4Payload { + full: Ja4Fingerprint, + raw: Ja4RawFingerprint, +} + +#[derive(Debug, Clone, PartialEq)] +struct Signature { + version: TlsVersion, + cipher_suites: Vec, + preferred_cipher_suite: Option, + extensions: Vec, + elliptic_curves: Vec, + elliptic_curve_point_formats: Vec, + signature_algorithms: Vec, + sni: Option, + alpn: Option, +} + +impl Signature { + fn generate_ja4_with_order(&self, original_order: bool) -> Ja4Payload { + let filtered_ciphers = filter_grease_values(&self.cipher_suites); + let filtered_extensions = filter_grease_values(&self.extensions); + let filtered_sig_algs = filter_grease_values(&self.signature_algorithms); + + let protocol = "t"; + let tls_version_str = format!("{}", self.version); + let sni_indicator = if self.sni.is_some() { "d" } else { "i" }; + let cipher_count = format!("{:02}", self.cipher_suites.len().min(99)); + let extension_count = format!("{:02}", 
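+        // JA4_a encodes the cipher and extension counts as two zero-padded digits,
+        // so .min(99) caps anything larger.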
self.extensions.len().min(99)); + let (alpn_first, alpn_last) = match &self.alpn { + Some(alpn) => first_last_alpn(alpn), + None => ('0', '0'), + }; + let ja4_a = format!( + "{protocol}{tls_version_str}{sni_indicator}{cipher_count}{extension_count}{alpn_first}{alpn_last}" + ); + + let mut ciphers_for_b = filtered_ciphers; + if !original_order { + ciphers_for_b.sort_unstable(); + } + let ja4_b_raw = ciphers_for_b + .iter() + .map(|c| format!("{c:04x}")) + .collect::>() + .join(","); + + let mut extensions_for_c = filtered_extensions; + if !original_order { + extensions_for_c.retain(|&ext| ext != 0x0000 && ext != 0x0010); + extensions_for_c.sort_unstable(); + } + let extensions_str = extensions_for_c + .iter() + .map(|e| format!("{e:04x}")) + .collect::>() + .join(","); + + let sig_algs_str = filtered_sig_algs + .iter() + .map(|s| format!("{s:04x}")) + .collect::>() + .join(","); + + let ja4_c_raw = if sig_algs_str.is_empty() { + extensions_str + } else if extensions_str.is_empty() { + sig_algs_str + } else { + format!("{extensions_str}_{sig_algs_str}") + }; + + let ja4_b_hash = hash12(&ja4_b_raw); + let ja4_c_hash = hash12(&ja4_c_raw); + + let ja4_hashed = format!("{ja4_a}_{ja4_b_hash}_{ja4_c_hash}"); + let ja4_raw_full = format!("{ja4_a}_{ja4_b_raw}_{ja4_c_raw}"); + + let full = if original_order { + Ja4Fingerprint::Unsorted(ja4_hashed) + } else { + Ja4Fingerprint::Sorted(ja4_hashed) + }; + let raw = if original_order { + Ja4RawFingerprint::Unsorted(ja4_raw_full) + } else { + Ja4RawFingerprint::Sorted(ja4_raw_full) + }; + + Ja4Payload { full, raw } + } +} + +fn extract_tls_signature_from_client_hello( + client_hello: &TlsClientHelloContents, +) -> Result { + let cipher_suites: Vec = client_hello.ciphers.iter().map(|c| c.0).collect(); + // Filter out GREASE values before selecting preferred cipher suite + let filtered_cipher_suites = filter_grease_values(&cipher_suites); + let preferred_cipher_suite = filtered_cipher_suites.first().copied(); + + let mut extensions = Vec::new(); + let mut sni = None; + let mut alpn = None; + let mut signature_algorithms = Vec::new(); + let mut elliptic_curves = Vec::new(); + let mut elliptic_curve_point_formats = Vec::new(); + + if let Some(ext_data) = &client_hello.ext + && let Ok((_remaining, parsed_extensions)) = parse_tls_extensions(ext_data) + { + for extension in &parsed_extensions { + let ext_type: u16 = TlsExtensionType::from(extension).into(); + if !is_grease_value(ext_type) { + extensions.push(ext_type); + } + match extension { + TlsExtension::SNI(sni_list) => { + if let Some((_, hostname)) = sni_list.first() { + sni = String::from_utf8(hostname.to_vec()).ok(); + } + } + TlsExtension::ALPN(alpn_list) => { + if let Some(protocol) = alpn_list.first() { + alpn = String::from_utf8(protocol.to_vec()).ok(); + } + } + TlsExtension::SignatureAlgorithms(sig_algs) => { + signature_algorithms = sig_algs.clone(); + } + TlsExtension::EllipticCurves(curves) => { + elliptic_curves = curves.iter().map(|c| c.0).collect(); + } + TlsExtension::EcPointFormats(formats) => { + elliptic_curve_point_formats = formats.to_vec(); + } + _ => {} + } + } + } + + let version = determine_tls_version(&client_hello.version, &extensions); + + Ok(Signature { + version, + cipher_suites, + preferred_cipher_suite, + extensions, + elliptic_curves, + elliptic_curve_point_formats, + signature_algorithms, + sni, + alpn, + }) +} + +fn determine_tls_version( + legacy_version: &tls_parser::TlsVersion, + extensions: &[u16], +) -> TlsVersion { + if 
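+    // The supported_versions extension is only offered by TLS 1.3-capable clients,
+    // so its presence overrides the legacy record-layer version matched below.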
extensions.contains(&TlsExtensionType::SupportedVersions.into()) {
+        return TlsVersion::V1_3;
+    }
+
+    match *legacy_version {
+        tls_parser::TlsVersion::Tls13 => TlsVersion::V1_3,
+        tls_parser::TlsVersion::Tls12 => TlsVersion::V1_2,
+        tls_parser::TlsVersion::Tls11 => TlsVersion::V1_1,
+        tls_parser::TlsVersion::Tls10 => TlsVersion::V1_0,
+        tls_parser::TlsVersion::Ssl30 => TlsVersion::Ssl3_0,
+        other => TlsVersion::Unknown(other.into()),
+    }
+}
+
+fn is_grease_value(value: u16) -> bool {
+    TLS_GREASE_VALUES.contains(&value)
+}
+
+fn filter_grease_values(values: &[u16]) -> Vec<u16> {
+    values
+        .iter()
+        .copied()
+        .filter(|v| !is_grease_value(*v))
+        .collect()
+}
+
+fn first_last_alpn(s: &str) -> (char, char) {
+    let replace_nonascii_with_9 = |c: char| if c.is_ascii() { c } else { '9' };
+    let mut chars = s.chars();
+    let first = chars.next().map(replace_nonascii_with_9).unwrap_or('0');
+    let last = if s.len() == 1 {
+        '0'
+    } else {
+        chars
+            .next_back()
+            .map(replace_nonascii_with_9)
+            .unwrap_or('0')
+    };
+    (first, last)
+}
+
+fn hash12(input: &str) -> String {
+    let digest = Sha256::digest(input.as_bytes());
+    let hex = format!("{:x}", digest);
+    hex[..12].to_string()
+}
+
+fn cipher_suite_to_string(cipher_suite: u16) -> String {
+    match cipher_suite {
+        0x1301 => "TLS_AES_128_GCM_SHA256".to_string(),
+        0x1302 => "TLS_AES_256_GCM_SHA384".to_string(),
+        0x1303 => "TLS_CHACHA20_POLY1305_SHA256".to_string(),
+        0x1304 => "TLS_AES_128_CCM_SHA256".to_string(),
+        0x1305 => "TLS_AES_128_CCM_8_SHA256".to_string(),
+        0xc02f => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256".to_string(),
+        0xc030 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(),
+        0xc02b => "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256".to_string(),
+        0xc02c => "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384".to_string(),
+        0xcca8 => "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256".to_string(),
+        0xcca9 => "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256".to_string(),
+        0x009e => "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256".to_string(),
+        0x009f => "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384".to_string(),
+        0x009c => "TLS_RSA_WITH_AES_128_GCM_SHA256".to_string(),
+        0x009d => "TLS_RSA_WITH_AES_256_GCM_SHA384".to_string(),
+        0x002f => "TLS_RSA_WITH_AES_128_CBC_SHA".to_string(),
+        0x003c => "TLS_RSA_WITH_AES_128_CBC_SHA256".to_string(),
+        0x003d => "TLS_RSA_WITH_AES_256_CBC_SHA256".to_string(),
+        _ => format!("UNKNOWN_CIPHER_{:04x}", cipher_suite),
+    }
+}
diff --git a/src/utils/tools.rs b/src/utils/tools.rs
index ee3e3fe..78f2ad4 100644
--- a/src/utils/tools.rs
+++ b/src/utils/tools.rs
@@ -1,278 +1,278 @@
-use crate::utils::structs::{InnerMap, UpstreamsDashMap, UpstreamsIdMap};
-use crate::utils::tls;
-use crate::utils::tls::CertificateConfig;
-use dashmap::DashMap;
-use log::{debug, error, info, warn};
-use notify::{event::ModifyKind, Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
-use port_check::is_port_reachable;
-use privdrop::PrivDrop;
-use sha2::{Digest, Sha256};
-use std::collections::{HashMap, HashSet};
-use std::fmt::Write;
-use std::net::SocketAddr;
-use std::os::unix::fs::MetadataExt;
-use std::str::FromStr;
-use std::sync::atomic::AtomicUsize;
-use std::sync::mpsc::{channel, Sender};
-use std::time::{Duration, Instant};
-use std::{fs, process, thread, time};
-
-#[allow(dead_code)]
-pub fn print_upstreams(upstreams: &UpstreamsDashMap) {
-    for host_entry in upstreams.iter() {
-        let hostname = host_entry.key();
-        println!("Hostname: {}", hostname);
-
-        for path_entry in host_entry.value().iter() {
-            let path = path_entry.key();
-            println!("
Path: {}", path); - for f in path_entry.value().0.clone() { - println!( - " IP: {}, Port: {}, SSL: {}, H2: {}, HTTPS Proxy Enabled: {}, Rate Limit: {}", - f.address, - f.port, - f.ssl_enabled, - f.http2_enabled, - f.https_proxy_enabled, - f.rate_limit.unwrap_or(0) - ); - } - } - } -} - - -pub fn clone_dashmap(original: &UpstreamsDashMap) -> UpstreamsDashMap { - let new_map: UpstreamsDashMap = DashMap::new(); - - for outer_entry in original.iter() { - let hostname = outer_entry.key(); - let inner_map = outer_entry.value(); - - let new_inner_map = DashMap::new(); - - for inner_entry in inner_map.iter() { - let path = inner_entry.key(); - let (vec, _) = inner_entry.value(); - let new_vec = vec.clone(); - let new_counter = AtomicUsize::new(0); - new_inner_map.insert(path.clone(), (new_vec, new_counter)); - } - new_map.insert(hostname.clone(), new_inner_map); - } - new_map -} - -pub fn clone_dashmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsDashMap) { - cloned.clear(); - for outer_entry in original.iter() { - let hostname = outer_entry.key(); - let inner_map = outer_entry.value(); - let new_inner_map = DashMap::new(); - for inner_entry in inner_map.iter() { - let path = inner_entry.key(); - let (vec, _) = inner_entry.value(); - let new_vec = vec.clone(); - let new_counter = AtomicUsize::new(0); - new_inner_map.insert(path.clone(), (new_vec, new_counter)); - } - cloned.insert(hostname.clone(), new_inner_map); - } -} - -pub fn compare_dashmaps(map1: &UpstreamsDashMap, map2: &UpstreamsDashMap) -> bool { - let keys1: HashSet<_> = map1.iter().map(|entry| entry.key().clone()).collect(); - let keys2: HashSet<_> = map2.iter().map(|entry| entry.key().clone()).collect(); - if keys1 != keys2 { - return false; - } - for entry1 in map1.iter() { - let hostname = entry1.key(); - let inner_map1 = entry1.value(); - let Some(inner_map2) = map2.get(hostname) else { - return false; - }; - let inner_keys1: HashSet<_> = inner_map1.iter().map(|e| e.key().clone()).collect(); - let inner_keys2: HashSet<_> = inner_map2.iter().map(|e| e.key().clone()).collect(); - if inner_keys1 != inner_keys2 { - return false; - } - for path_entry in inner_map1.iter() { - let path = path_entry.key(); - let (vec1, _counter1) = path_entry.value(); - let Some(entry2) = inner_map2.get(path) else { - return false; // Path exists in map1 but not in map2 - }; - let (vec2, _counter2) = entry2.value(); - let set1: HashSet<_> = vec1.iter().collect(); - let set2: HashSet<_> = vec2.iter().collect(); - if set1 != set2 { - return false; - } - } - } - true -} - -pub fn merge_headers(target: &DashMap>, source: &DashMap>) { - for entry in source.iter() { - let global_key = entry.key().clone(); - let global_values = entry.value().clone(); - let mut target_entry = target.entry(global_key).or_insert_with(Vec::new); - target_entry.extend(global_values); - } -} - -pub fn clone_idmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsIdMap) { - cloned.clear(); - for outer_entry in original.iter() { - let inner_map = outer_entry.value(); - let new_inner_map = DashMap::new(); - for inner_entry in inner_map.iter() { - let path = inner_entry.key(); - let (vec, _) = inner_entry.value(); - let new_vec = vec.clone(); - for x in vec.iter() { - let mut id = String::new(); - write!(&mut id, "{}:{}:{}", x.address, x.port, x.ssl_enabled).unwrap(); - let mut hasher = Sha256::new(); - hasher.update(id.clone().into_bytes()); - let hash = hasher.finalize(); - let hex_hash = base16ct::lower::encode_string(&hash); - let hh = hex_hash[0..50].to_string(); - let to_add 
= InnerMap { - address: hh.clone(), - port: 0, - ssl_enabled: false, - http2_enabled: false, - https_proxy_enabled: false, - rate_limit: None, - healthcheck: None, - disable_access_log: false, - }; - cloned.insert(id, to_add); - cloned.insert(hh, x.to_owned()); - } - new_inner_map.insert(path.clone(), new_vec); - } - } -} - -pub fn listdir(dir: String) -> Vec { - let mut f = HashMap::new(); - let mut certificate_configs: Vec = vec![]; - let paths = fs::read_dir(dir).unwrap(); - for path in paths { - let path_str = path.unwrap().path().to_str().unwrap().to_owned(); - if path_str.ends_with(".crt") { - let name = path_str.replace(".crt", ""); - let key_path = name.clone() + ".key"; - // Only add certificate config if both cert and key files exist - // This prevents errors when .crt is created before .key during certificate writing - if std::path::Path::new(&key_path).exists() { - let mut inner = vec![]; - let domain = name.split("/").collect::>(); - inner.push(name.clone() + ".crt"); - inner.push(key_path.clone()); - f.insert(domain[domain.len() - 1].to_owned(), inner); - let y = CertificateConfig { - cert_path: name.clone() + ".crt", - key_path: key_path, - }; - certificate_configs.push(y); - } else { - debug!("Skipping certificate {} - key file does not exist yet", name); - } - } - } - for (_, v) in f.iter() { - // Double-check both files exist before adding - if std::path::Path::new(&v[0]).exists() && std::path::Path::new(&v[1]).exists() { - let y = CertificateConfig { - cert_path: v[0].clone(), - key_path: v[1].clone(), - }; - certificate_configs.push(y); - } - } - certificate_configs -} - -pub fn watch_folder(path: String, sender: Sender>) -> notify::Result<()> { - let (tx, rx) = channel(); - let mut watcher = RecommendedWatcher::new(tx, Config::default())?; - watcher.watch(path.as_ref(), RecursiveMode::Recursive)?; - info!("Watching for certificates in : {}", path); - let certificate_configs = listdir(path.clone()); - sender.send(certificate_configs)?; - let mut start = Instant::now(); - loop { - match rx.recv_timeout(Duration::from_secs(1)) { - Ok(Ok(event)) => match &event.kind { - EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(_) | EventKind::Remove(_) => { - if start.elapsed() > Duration::from_secs(1) { - start = Instant::now(); - // Add a small delay to allow both .crt and .key files to be written - // This prevents race conditions when certificates are being saved - thread::sleep(Duration::from_millis(100)); - let certificate_configs = listdir(path.clone()); - sender.send(certificate_configs)?; - info!("Certificate changed: {:?}, {:?}", event.kind, event.paths); - } - } - _ => {} - }, - Ok(Err(e)) => error!("Watch error: {:?}", e), - Err(_) => {} - } - } -} - -pub fn drop_priv(user: String, group: String, http_addr: String, tls_addr: Option) { - thread::sleep(time::Duration::from_millis(10)); - loop { - thread::sleep(time::Duration::from_millis(10)); - if is_port_reachable(http_addr.clone()) { - break; - } - } - if let Some(tls_addr) = tls_addr { - loop { - thread::sleep(time::Duration::from_millis(10)); - if is_port_reachable(tls_addr.clone()) { - break; - } - } - } - info!("Dropping ROOT privileges to: {}:{}", user, group); - if let Err(e) = PrivDrop::default().user(user).group(group).apply() { - error!("Failed to drop privileges: {}", e); - process::exit(1) - } -} - -pub fn check_priv(addr: &str) { - // Skip privilege check if address is empty or invalid (e.g., when using default config) - if addr.is_empty() { - return; - } - - let port = match 
SocketAddr::from_str(addr).map(|sa| sa.port()) { - Ok(p) => p, - Err(_) => { - warn!("Invalid socket address '{}', skipping privilege check", addr); - return; - } - }; - - match port < 1024 { - true => { - let meta = std::fs::metadata("/proc/self").map(|m| m.uid()).unwrap(); - if meta != 0 { - error!("Running on privileged port requires to start as ROOT"); - process::exit(1) - } - } - false => {} - } -} +use crate::utils::structs::{InnerMap, UpstreamsDashMap, UpstreamsIdMap}; +use crate::utils::tls; +use crate::utils::tls::CertificateConfig; +use dashmap::DashMap; +use log::{debug, error, info, warn}; +use notify::{event::ModifyKind, Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use port_check::is_port_reachable; +use privdrop::PrivDrop; +use sha2::{Digest, Sha256}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Write; +use std::net::SocketAddr; +use std::os::unix::fs::MetadataExt; +use std::str::FromStr; +use std::sync::atomic::AtomicUsize; +use std::sync::mpsc::{channel, Sender}; +use std::time::{Duration, Instant}; +use std::{fs, process, thread, time}; + +#[allow(dead_code)] +pub fn print_upstreams(upstreams: &UpstreamsDashMap) { + for host_entry in upstreams.iter() { + let hostname = host_entry.key(); + println!("Hostname: {}", hostname); + + for path_entry in host_entry.value().iter() { + let path = path_entry.key(); + println!(" Path: {}", path); + for f in path_entry.value().0.clone() { + println!( + " IP: {}, Port: {}, SSL: {}, H2: {}, HTTPS Proxy Enabled: {}, Rate Limit: {}", + f.address, + f.port, + f.ssl_enabled, + f.http2_enabled, + f.https_proxy_enabled, + f.rate_limit.unwrap_or(0) + ); + } + } + } +} + + +pub fn clone_dashmap(original: &UpstreamsDashMap) -> UpstreamsDashMap { + let new_map: UpstreamsDashMap = DashMap::new(); + + for outer_entry in original.iter() { + let hostname = outer_entry.key(); + let inner_map = outer_entry.value(); + + let new_inner_map = DashMap::new(); + + for inner_entry in inner_map.iter() { + let path = inner_entry.key(); + let (vec, _) = inner_entry.value(); + let new_vec = vec.clone(); + let new_counter = AtomicUsize::new(0); + new_inner_map.insert(path.clone(), (new_vec, new_counter)); + } + new_map.insert(hostname.clone(), new_inner_map); + } + new_map +} + +pub fn clone_dashmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsDashMap) { + cloned.clear(); + for outer_entry in original.iter() { + let hostname = outer_entry.key(); + let inner_map = outer_entry.value(); + let new_inner_map = DashMap::new(); + for inner_entry in inner_map.iter() { + let path = inner_entry.key(); + let (vec, _) = inner_entry.value(); + let new_vec = vec.clone(); + let new_counter = AtomicUsize::new(0); + new_inner_map.insert(path.clone(), (new_vec, new_counter)); + } + cloned.insert(hostname.clone(), new_inner_map); + } +} + +pub fn compare_dashmaps(map1: &UpstreamsDashMap, map2: &UpstreamsDashMap) -> bool { + let keys1: HashSet<_> = map1.iter().map(|entry| entry.key().clone()).collect(); + let keys2: HashSet<_> = map2.iter().map(|entry| entry.key().clone()).collect(); + if keys1 != keys2 { + return false; + } + for entry1 in map1.iter() { + let hostname = entry1.key(); + let inner_map1 = entry1.value(); + let Some(inner_map2) = map2.get(hostname) else { + return false; + }; + let inner_keys1: HashSet<_> = inner_map1.iter().map(|e| e.key().clone()).collect(); + let inner_keys2: HashSet<_> = inner_map2.iter().map(|e| e.key().clone()).collect(); + if inner_keys1 != inner_keys2 { + return false; + } + for path_entry in 
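+        // Per-path upstream lists are compared as sets below, so a reordering
+        // of the same upstreams is not treated as a configuration change.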
inner_map1.iter() { + let path = path_entry.key(); + let (vec1, _counter1) = path_entry.value(); + let Some(entry2) = inner_map2.get(path) else { + return false; // Path exists in map1 but not in map2 + }; + let (vec2, _counter2) = entry2.value(); + let set1: HashSet<_> = vec1.iter().collect(); + let set2: HashSet<_> = vec2.iter().collect(); + if set1 != set2 { + return false; + } + } + } + true +} + +pub fn merge_headers(target: &DashMap>, source: &DashMap>) { + for entry in source.iter() { + let global_key = entry.key().clone(); + let global_values = entry.value().clone(); + let mut target_entry = target.entry(global_key).or_insert_with(Vec::new); + target_entry.extend(global_values); + } +} + +pub fn clone_idmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsIdMap) { + cloned.clear(); + for outer_entry in original.iter() { + let inner_map = outer_entry.value(); + let new_inner_map = DashMap::new(); + for inner_entry in inner_map.iter() { + let path = inner_entry.key(); + let (vec, _) = inner_entry.value(); + let new_vec = vec.clone(); + for x in vec.iter() { + let mut id = String::new(); + write!(&mut id, "{}:{}:{}", x.address, x.port, x.ssl_enabled).unwrap(); + let mut hasher = Sha256::new(); + hasher.update(id.clone().into_bytes()); + let hash = hasher.finalize(); + let hex_hash = base16ct::lower::encode_string(&hash); + let hh = hex_hash[0..50].to_string(); + let to_add = InnerMap { + address: hh.clone(), + port: 0, + ssl_enabled: false, + http2_enabled: false, + https_proxy_enabled: false, + rate_limit: None, + healthcheck: None, + disable_access_log: false, + }; + cloned.insert(id, to_add); + cloned.insert(hh, x.to_owned()); + } + new_inner_map.insert(path.clone(), new_vec); + } + } +} + +pub fn listdir(dir: String) -> Vec { + let mut f = HashMap::new(); + let mut certificate_configs: Vec = vec![]; + let paths = fs::read_dir(dir).unwrap(); + for path in paths { + let path_str = path.unwrap().path().to_str().unwrap().to_owned(); + if path_str.ends_with(".crt") { + let name = path_str.replace(".crt", ""); + let key_path = name.clone() + ".key"; + // Only add certificate config if both cert and key files exist + // This prevents errors when .crt is created before .key during certificate writing + if std::path::Path::new(&key_path).exists() { + let mut inner = vec![]; + let domain = name.split("/").collect::>(); + inner.push(name.clone() + ".crt"); + inner.push(key_path.clone()); + f.insert(domain[domain.len() - 1].to_owned(), inner); + let y = CertificateConfig { + cert_path: name.clone() + ".crt", + key_path: key_path, + }; + certificate_configs.push(y); + } else { + debug!("Skipping certificate {} - key file does not exist yet", name); + } + } + } + for (_, v) in f.iter() { + // Double-check both files exist before adding + if std::path::Path::new(&v[0]).exists() && std::path::Path::new(&v[1]).exists() { + let y = CertificateConfig { + cert_path: v[0].clone(), + key_path: v[1].clone(), + }; + certificate_configs.push(y); + } + } + certificate_configs +} + +pub fn watch_folder(path: String, sender: Sender>) -> notify::Result<()> { + let (tx, rx) = channel(); + let mut watcher = RecommendedWatcher::new(tx, Config::default())?; + watcher.watch(path.as_ref(), RecursiveMode::Recursive)?; + info!("Watching for certificates in : {}", path); + let certificate_configs = listdir(path.clone()); + sender.send(certificate_configs)?; + let mut start = Instant::now(); + loop { + match rx.recv_timeout(Duration::from_secs(1)) { + Ok(Ok(event)) => match &event.kind { + 
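+                // Only data modifications, creations, and removals trigger a
+                // certificate rescan; all other event kinds fall through to the
+                // empty catch-all arm below.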
EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(_) | EventKind::Remove(_) => { + if start.elapsed() > Duration::from_secs(1) { + start = Instant::now(); + // Add a small delay to allow both .crt and .key files to be written + // This prevents race conditions when certificates are being saved + thread::sleep(Duration::from_millis(100)); + let certificate_configs = listdir(path.clone()); + sender.send(certificate_configs)?; + info!("Certificate changed: {:?}, {:?}", event.kind, event.paths); + } + } + _ => {} + }, + Ok(Err(e)) => error!("Watch error: {:?}", e), + Err(_) => {} + } + } +} + +pub fn drop_priv(user: String, group: String, http_addr: String, tls_addr: Option) { + thread::sleep(time::Duration::from_millis(10)); + loop { + thread::sleep(time::Duration::from_millis(10)); + if is_port_reachable(http_addr.clone()) { + break; + } + } + if let Some(tls_addr) = tls_addr { + loop { + thread::sleep(time::Duration::from_millis(10)); + if is_port_reachable(tls_addr.clone()) { + break; + } + } + } + info!("Dropping ROOT privileges to: {}:{}", user, group); + if let Err(e) = PrivDrop::default().user(user).group(group).apply() { + error!("Failed to drop privileges: {}", e); + process::exit(1) + } +} + +pub fn check_priv(addr: &str) { + // Skip privilege check if address is empty or invalid (e.g., when using default config) + if addr.is_empty() { + return; + } + + let port = match SocketAddr::from_str(addr).map(|sa| sa.port()) { + Ok(p) => p, + Err(_) => { + warn!("Invalid socket address '{}', skipping privilege check", addr); + return; + } + }; + + match port < 1024 { + true => { + let meta = std::fs::metadata("/proc/self").map(|m| m.uid()).unwrap(); + if meta != 0 { + error!("Running on privileged port requires to start as ROOT"); + process::exit(1) + } + } + false => {} + } +} diff --git a/src/waf/actions/captcha.rs b/src/waf/actions/captcha.rs index 74644a8..92a3232 100644 --- a/src/waf/actions/captcha.rs +++ b/src/waf/actions/captcha.rs @@ -1,1011 +1,1011 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use chrono::Utc; -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, OnceCell}; -use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; -use uuid::Uuid; - -use crate::redis::RedisManager; -use crate::http_client::get_global_reqwest_client; - -/// Captcha provider types supported by Gen0Sec -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] -pub enum CaptchaProvider { - #[serde(rename = "hcaptcha")] - HCaptcha, - #[serde(rename = "recaptcha")] - ReCaptcha, - #[serde(rename = "turnstile")] - Turnstile, -} - -impl Default for CaptchaProvider { - fn default() -> Self { - CaptchaProvider::HCaptcha - } -} - -impl std::str::FromStr for CaptchaProvider { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "hcaptcha" => Ok(CaptchaProvider::HCaptcha), - "recaptcha" => Ok(CaptchaProvider::ReCaptcha), - "turnstile" => Ok(CaptchaProvider::Turnstile), - _ => Err(anyhow::anyhow!("Invalid captcha provider: {}", s)), - } - } -} - -/// Captcha validation request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaValidationRequest { - pub response_token: String, - pub ip_address: String, - pub user_agent: Option, - pub site_key: String, - pub secret_key: String, - pub provider: CaptchaProvider, -} - -/// Captcha validation response -#[derive(Debug, 
Clone, Serialize, Deserialize)] -pub struct CaptchaValidationResponse { - pub success: bool, - pub error_codes: Option>, - pub challenge_ts: Option, - pub hostname: Option, - pub score: Option, - pub action: Option, -} - -/// JWT Claims for captcha tokens -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaClaims { - /// Standard JWT claims - pub sub: String, // Subject (user identifier) - pub iss: String, // Issuer - pub aud: String, // Audience - pub exp: i64, // Expiration time - pub iat: i64, // Issued at - pub jti: String, // JWT ID (unique identifier) - - /// Custom captcha claims - pub ip_address: String, - pub user_agent: String, - pub ja4_fingerprint: Option, - pub captcha_provider: String, - pub captcha_validated: bool, -} - -/// Captcha token with JWT-based security -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaToken { - pub token: String, - pub claims: CaptchaClaims, -} - -/// Cached captcha validation result -#[derive(Debug, Clone)] -pub struct CachedCaptchaResult { - pub is_valid: bool, - pub expires_at: Instant, -} - -/// Captcha action configuration -#[derive(Debug, Clone)] -pub struct CaptchaConfig { - pub site_key: String, - pub secret_key: String, - pub jwt_secret: String, - pub provider: CaptchaProvider, - pub token_ttl_seconds: u64, - pub validation_cache_ttl_seconds: u64, -} - -/// Captcha client for validation and token management -pub struct CaptchaClient { - config: CaptchaConfig, - validation_cache: Arc>>, - validated_tokens: Arc>>, // JTI -> expiration time -} - -impl CaptchaClient { - pub fn new( - config: CaptchaConfig, - ) -> Self { - Self { - config, - validation_cache: Arc::new(RwLock::new(HashMap::new())), - validated_tokens: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Validate a captcha response token - pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { - log::info!("Starting captcha validation for IP: {}, provider: {:?}", - request.ip_address, self.config.provider); - - // Check if captcha response is provided - if request.response_token.is_empty() { - log::warn!("No captcha response provided for IP: {}", request.ip_address); - return Ok(false); - } - - log::debug!("Captcha response token length: {}", request.response_token.len()); - - // Check validation cache first - let cache_key = format!("{}:{}", request.response_token, request.ip_address); - if let Some(cached) = self.get_validation_cache(&cache_key).await { - if cached.expires_at > Instant::now() { - log::debug!("Captcha validation for {} found in cache", request.ip_address); - return Ok(cached.is_valid); - } else { - self.remove_validation_cache(&cache_key).await; - } - } - - // Validate with provider API - let is_valid = match self.config.provider { - CaptchaProvider::HCaptcha => self.validate_hcaptcha(&request).await?, - CaptchaProvider::ReCaptcha => self.validate_recaptcha(&request).await?, - CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, - }; - - log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); - - // Cache the result - self.set_validation_cache(&cache_key, is_valid).await; - - Ok(is_valid) - } - - /// Generate a secure JWT captcha token - pub async fn generate_token( - &self, - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, - ) -> Result { - let now = Utc::now(); - let exp = now + chrono::Duration::seconds(self.config.token_ttl_seconds as i64); - let jti = Uuid::new_v4().to_string(); - - let claims = CaptchaClaims { - sub: 
format!("captcha:{}", ip_address), - iss: "arxignis-synapse".to_string(), - aud: "captcha-validation".to_string(), - exp: exp.timestamp(), - iat: now.timestamp(), - jti: jti.clone(), - ip_address: ip_address.clone(), - user_agent: user_agent.clone(), - ja4_fingerprint, - captcha_provider: format!("{:?}", self.config.provider), - captcha_validated: false, - }; - - let header = Header::new(Algorithm::HS256); - let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); - - let token = encode(&header, &claims, &encoding_key) - .context("Failed to encode JWT token")?; - - let captcha_token = CaptchaToken { - token: token.clone(), - claims: claims.clone(), - }; - - // Store token in Redis for validation (optional, JWT is self-contained) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); - let mut redis = redis_manager.get_connection(); - let token_data = serde_json::to_string(&captcha_token) - .context("Failed to serialize captcha token")?; - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to store captcha token in Redis")?; - } - - Ok(captcha_token) - } - - /// Validate a JWT captcha token - pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Check if token is expired (JWT handles this automatically, but double-check) - let now = Utc::now().timestamp(); - if claims.exp < now { - log::debug!("JWT token expired"); - return Ok(false); - } - - // Verify IP and User-Agent binding - if claims.ip_address != ip_address || claims.user_agent != user_agent { - log::warn!("JWT token validation failed: IP or User-Agent mismatch"); - return Ok(false); - } - - // Check Redis first for updated token state - let mut captcha_validated = claims.captcha_validated; - log::debug!("Initial JWT token captcha_validated: {}", captcha_validated); - - // Check in-memory cache first (faster) - { - let validated_tokens = self.validated_tokens.read().await; - if let Some(expiration) = validated_tokens.get(&claims.jti) { - if *expiration > Instant::now() { - captcha_validated = true; - log::debug!("Found validated token JTI {} in memory cache", claims.jti); - } else { - log::debug!("Token JTI {} expired in memory cache", claims.jti); - } - } - } - - // If not found in memory cache, check Redis - if !captcha_validated { - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Looking up token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - match redis.get::<_, String>(&key).await { - Ok(token_data_str) => { - log::debug!("Found token data in Redis: {}", token_data_str); - if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { - captcha_validated = updated_token.claims.captcha_validated; - log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); - - // Update memory cache if found in Redis - if captcha_validated { - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - let mut validated_tokens = 
self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - } - } else { - log::warn!("Failed to parse token data from Redis"); - } - } - Err(e) => { - log::debug!("Redis token lookup failed for JTI {}: {}", claims.jti, e); - } - } - } else { - log::debug!("Redis manager not available"); - } - } - - // Check if captcha was validated (either from JWT or Redis) - if !captcha_validated { - log::debug!("JWT token not validated for captcha"); - return Ok(false); - } - - // Optional: Check Redis blacklist for revoked tokens - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - match redis.exists::<_, bool>(&blacklist_key).await { - Ok(true) => { - log::debug!("JWT token {} is blacklisted", claims.jti); - return Ok(false); - } - Ok(false) => { - // Token not blacklisted, continue validation - } - Err(e) => { - log::warn!("Redis blacklist check error for JWT {}: {}", claims.jti, e); - // Continue validation despite Redis error - } - } - } - - Ok(true) - } - Err(e) => { - log::warn!("JWT token validation failed: {}", e); - Ok(false) - } - } - } - - /// Mark a JWT token as validated after successful captcha completion - pub async fn mark_token_validated(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Store the JTI as validated in memory cache - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - { - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); - } - - // Also update Redis cache if available (for persistence across restarts) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Storing updated token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - let mut updated_claims = claims.clone(); - updated_claims.captcha_validated = true; - - let updated_captcha_token = CaptchaToken { - token: token.to_string(), - claims: updated_claims, - }; - let token_data = serde_json::to_string(&updated_captcha_token) - .context("Failed to serialize updated captcha token")?; - - log::debug!("Token data to store: {}", token_data); - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to update captcha token in Redis")?; - - log::debug!("Successfully stored updated token in Redis"); - } else { - log::debug!("Redis manager not available for token storage"); - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for validation marking: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Revoke a JWT token by adding it to blacklist - pub async fn revoke_token(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match 
decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Add to Redis blacklist - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - let _: () = redis - .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) - .await - .context("Failed to add token to blacklist")?; - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for revocation: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Apply captcha challenge (return HTML form) - pub fn apply_captcha_challenge(&self, site_key: &str) -> String { - self.render_captcha_template(site_key, None) - } - - /// Apply captcha challenge with JWT token (return HTML form) - pub fn apply_captcha_challenge_with_token(&self, site_key: &str, jwt_token: &str) -> String { - self.render_captcha_template(site_key, Some(jwt_token)) - } - - /// Render captcha template based on provider - fn render_captcha_template(&self, site_key: &str, jwt_token: Option<&str>) -> String { - let (frontend_js, frontend_key, callback_attr) = match self.config.provider { - CaptchaProvider::HCaptcha => ( - "https://js.hcaptcha.com/1/api.js", - "h-captcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::ReCaptcha => ( - "https://www.recaptcha.net/recaptcha/api.js", - "g-recaptcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::Turnstile => ( - "https://challenges.cloudflare.com/turnstile/v0/api.js", - "cf-turnstile", - "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"" - ), - }; - - let jwt_token_input = if let Some(token) = jwt_token { - format!(r#""#, token) - } else { - r#""#.to_string() - }; - - let html_template = format!( - r#" - - - Gen0Sec Captcha - - - - - - -
-            [challenge page body garbled during extraction; recoverable content: a heading
-             "Gen0Sec Captcha", the prompt "Please complete the security verification below
-             to continue.", the provider widget container, and the hidden JWT token input
-             rendered from the format arguments below]
- - - -"#, - frontend_js, - frontend_key, - site_key, - callback_attr, - jwt_token_input - ); - html_template - } - - /// Validate with hCaptcha API - async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("sitekey", &request.site_key); - params.insert("remoteip", &request.ip_address); - - log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://hcaptcha.com/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send hCaptcha validation request")?; - - log::info!("hCaptcha validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("hCaptcha service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse hCaptcha response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("hCaptcha secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid hCaptcha response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("hCaptcha response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("hCaptcha validation failed with error code: {}", error_code); - } - } - } - } - log::info!("hCaptcha validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with reCAPTCHA API - async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://www.recaptcha.net/recaptcha/api/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send reCAPTCHA validation request")?; - - log::info!("reCAPTCHA validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("reCAPTCHA service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse reCAPTCHA response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("reCAPTCHA secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid reCAPTCHA response from user"); - return Ok(false); - } - 
"timeout-or-duplicate" => { - log::info!("reCAPTCHA response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("reCAPTCHA validation failed with error code: {}", error_code); - } - } - } - } - log::info!("reCAPTCHA validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with Cloudflare Turnstile API - async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("Turnstile validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://challenges.cloudflare.com/turnstile/v0/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send Turnstile validation request")?; - - log::info!("Turnstile validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("Turnstile service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse Turnstile response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("Turnstile secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid Turnstile response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("Turnstile response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("Turnstile validation failed with error code: {}", error_code); - } - } - } - } - log::info!("Turnstile validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Get the captcha backend response key name for the current provider - pub fn get_captcha_backend_key(&self) -> &'static str { - match self.config.provider { - CaptchaProvider::HCaptcha => "h-captcha-response", - CaptchaProvider::ReCaptcha => "g-recaptcha-response", - CaptchaProvider::Turnstile => "cf-turnstile-response", - } - } - - /// Get validation result from cache - async fn get_validation_cache(&self, key: &str) -> Option { - let cache = self.validation_cache.read().await; - cache.get(key).cloned() - } - - /// Set validation result in cache - async fn set_validation_cache(&self, key: &str, is_valid: bool) { - let mut cache = self.validation_cache.write().await; - cache.insert( - key.to_string(), - CachedCaptchaResult { - is_valid, - expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds), - }, - ); - } - - /// Remove validation result from cache - async fn remove_validation_cache(&self, key: &str) { - let mut cache = self.validation_cache.write().await; - cache.remove(key); - } - - /// Clean up expired cache entries - pub async fn cleanup_cache(&self) { - let mut cache = self.validation_cache.write().await; - let now = Instant::now(); - cache.retain(|_, cached| cached.expires_at > now); - - // Also clean up expired validated tokens - let mut validated_tokens = 
self.validated_tokens.write().await; - validated_tokens.retain(|_, expiration| *expiration > now); - } -} - -/// Global captcha client instance -static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); - -/// Initialize the global captcha client -pub async fn init_captcha_client( - config: CaptchaConfig, -) -> Result<()> { - let client = Arc::new(CaptchaClient::new(config)); - - CAPTCHA_CLIENT.set(client) - .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; - - Ok(()) -} - -/// Validate captcha response -pub async fn validate_captcha_response( - response_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - let request = CaptchaValidationRequest { - response_token, - ip_address, - user_agent: user_agent, - site_key: client.config.site_key.clone(), - secret_key: client.config.secret_key.clone(), - provider: client.config.provider.clone(), - }; - - client.validate_captcha(request).await -} - -/// Generate captcha token -pub async fn generate_captcha_token( - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.generate_token(ip_address, user_agent, ja4_fingerprint).await -} - -/// Validate captcha token -pub async fn validate_captcha_token( - token: &str, - ip_address: &str, - user_agent: &str, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.validate_token(token, ip_address, user_agent).await -} - -/// Apply captcha challenge -pub fn apply_captcha_challenge() -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge(&client.config.site_key)) -} - -/// Apply captcha challenge with JWT token -pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token)) -} - -/// Get the captcha backend response key name -pub fn get_captcha_backend_key() -> Result<&'static str> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.get_captcha_backend_key()) -} - -/// Mark a JWT token as validated after successful captcha completion -pub async fn mark_captcha_token_validated(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.mark_token_validated(token).await -} - -/// Revoke a JWT token -pub async fn revoke_captcha_token(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.revoke_token(token).await -} - -/// Validate captcha response and mark token as validated -pub async fn validate_and_mark_captcha( - response_token: String, - jwt_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", - ip_address, response_token.len(), jwt_token.len()); - - // First validate the captcha response - let is_valid = validate_captcha_response(response_token, 
ip_address.clone(), user_agent.clone()).await?; - - log::info!("Captcha validation result: {}", is_valid); - - if is_valid { - // Only try to mark JWT token as validated if it's not empty - if !jwt_token.is_empty() { - if let Err(e) = mark_captcha_token_validated(&jwt_token).await { - log::warn!("Failed to mark JWT token as validated: {}", e); - // Don't return false here - captcha validation succeeded - } else { - log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address); - } - } else { - log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address); - } - } else { - log::warn!("Captcha validation failed for IP: {}", ip_address); - } - - Ok(is_valid) -} - -/// Start periodic cache cleanup task -pub async fn start_cache_cleanup_task() { - tokio::spawn(async { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - loop { - interval.tick().await; - if let Some(client) = CAPTCHA_CLIENT.get() { - client.cleanup_cache().await; - } - } - }); -} +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::{Context, Result}; +use chrono::Utc; +use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; +use tokio::sync::{RwLock, OnceCell}; +use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; +use uuid::Uuid; + +use crate::redis::RedisManager; +use crate::http_client::get_global_reqwest_client; + +/// Captcha provider types supported by Gen0Sec +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] +pub enum CaptchaProvider { + #[serde(rename = "hcaptcha")] + HCaptcha, + #[serde(rename = "recaptcha")] + ReCaptcha, + #[serde(rename = "turnstile")] + Turnstile, +} + +impl Default for CaptchaProvider { + fn default() -> Self { + CaptchaProvider::HCaptcha + } +} + +impl std::str::FromStr for CaptchaProvider { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "hcaptcha" => Ok(CaptchaProvider::HCaptcha), + "recaptcha" => Ok(CaptchaProvider::ReCaptcha), + "turnstile" => Ok(CaptchaProvider::Turnstile), + _ => Err(anyhow::anyhow!("Invalid captcha provider: {}", s)), + } + } +} + +/// Captcha validation request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaValidationRequest { + pub response_token: String, + pub ip_address: String, + pub user_agent: Option, + pub site_key: String, + pub secret_key: String, + pub provider: CaptchaProvider, +} + +/// Captcha validation response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaValidationResponse { + pub success: bool, + pub error_codes: Option>, + pub challenge_ts: Option, + pub hostname: Option, + pub score: Option, + pub action: Option, +} + +/// JWT Claims for captcha tokens +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaClaims { + /// Standard JWT claims + pub sub: String, // Subject (user identifier) + pub iss: String, // Issuer + pub aud: String, // Audience + pub exp: i64, // Expiration time + pub iat: i64, // Issued at + pub jti: String, // JWT ID (unique identifier) + + /// Custom captcha claims + pub ip_address: String, + pub user_agent: String, + pub ja4_fingerprint: Option, + pub captcha_provider: String, + pub captcha_validated: bool, +} + +/// Captcha token with JWT-based security +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CaptchaToken { + pub token: String, + pub claims: CaptchaClaims, +} + +/// Cached captcha validation result 
+#[derive(Debug, Clone)] +pub struct CachedCaptchaResult { + pub is_valid: bool, + pub expires_at: Instant, +} + +/// Captcha action configuration +#[derive(Debug, Clone)] +pub struct CaptchaConfig { + pub site_key: String, + pub secret_key: String, + pub jwt_secret: String, + pub provider: CaptchaProvider, + pub token_ttl_seconds: u64, + pub validation_cache_ttl_seconds: u64, +} + +/// Captcha client for validation and token management +pub struct CaptchaClient { + config: CaptchaConfig, + validation_cache: Arc>>, + validated_tokens: Arc>>, // JTI -> expiration time +} + +impl CaptchaClient { + pub fn new( + config: CaptchaConfig, + ) -> Self { + Self { + config, + validation_cache: Arc::new(RwLock::new(HashMap::new())), + validated_tokens: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Validate a captcha response token + pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { + log::info!("Starting captcha validation for IP: {}, provider: {:?}", + request.ip_address, self.config.provider); + + // Check if captcha response is provided + if request.response_token.is_empty() { + log::warn!("No captcha response provided for IP: {}", request.ip_address); + return Ok(false); + } + + log::debug!("Captcha response token length: {}", request.response_token.len()); + + // Check validation cache first + let cache_key = format!("{}:{}", request.response_token, request.ip_address); + if let Some(cached) = self.get_validation_cache(&cache_key).await { + if cached.expires_at > Instant::now() { + log::debug!("Captcha validation for {} found in cache", request.ip_address); + return Ok(cached.is_valid); + } else { + self.remove_validation_cache(&cache_key).await; + } + } + + // Validate with provider API + let is_valid = match self.config.provider { + CaptchaProvider::HCaptcha => self.validate_hcaptcha(&request).await?, + CaptchaProvider::ReCaptcha => self.validate_recaptcha(&request).await?, + CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, + }; + + log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); + + // Cache the result + self.set_validation_cache(&cache_key, is_valid).await; + + Ok(is_valid) + } + + /// Generate a secure JWT captcha token + pub async fn generate_token( + &self, + ip_address: String, + user_agent: String, + ja4_fingerprint: Option, + ) -> Result { + let now = Utc::now(); + let exp = now + chrono::Duration::seconds(self.config.token_ttl_seconds as i64); + let jti = Uuid::new_v4().to_string(); + + let claims = CaptchaClaims { + sub: format!("captcha:{}", ip_address), + iss: "arxignis-synapse".to_string(), + aud: "captcha-validation".to_string(), + exp: exp.timestamp(), + iat: now.timestamp(), + jti: jti.clone(), + ip_address: ip_address.clone(), + user_agent: user_agent.clone(), + ja4_fingerprint, + captcha_provider: format!("{:?}", self.config.provider), + captcha_validated: false, + }; + + let header = Header::new(Algorithm::HS256); + let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); + + let token = encode(&header, &claims, &encoding_key) + .context("Failed to encode JWT token")?; + + let captcha_token = CaptchaToken { + token: token.clone(), + claims: claims.clone(), + }; + + // Store token in Redis for validation (optional, JWT is self-contained) + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); + let mut redis = redis_manager.get_connection(); + let token_data = 
serde_json::to_string(&captcha_token) + .context("Failed to serialize captcha token")?; + + let _: () = redis + .set_ex(&key, token_data, self.config.token_ttl_seconds) + .await + .context("Failed to store captcha token in Redis")?; + } + + Ok(captcha_token) + } + + /// Validate a JWT captcha token + pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Check if token is expired (JWT handles this automatically, but double-check) + let now = Utc::now().timestamp(); + if claims.exp < now { + log::debug!("JWT token expired"); + return Ok(false); + } + + // Verify IP and User-Agent binding + if claims.ip_address != ip_address || claims.user_agent != user_agent { + log::warn!("JWT token validation failed: IP or User-Agent mismatch"); + return Ok(false); + } + + // Check Redis first for updated token state + let mut captcha_validated = claims.captcha_validated; + log::debug!("Initial JWT token captcha_validated: {}", captcha_validated); + + // Check in-memory cache first (faster) + { + let validated_tokens = self.validated_tokens.read().await; + if let Some(expiration) = validated_tokens.get(&claims.jti) { + if *expiration > Instant::now() { + captcha_validated = true; + log::debug!("Found validated token JTI {} in memory cache", claims.jti); + } else { + log::debug!("Token JTI {} expired in memory cache", claims.jti); + } + } + } + + // If not found in memory cache, check Redis + if !captcha_validated { + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + log::debug!("Looking up token in Redis with key: {}", key); + + let mut redis = redis_manager.get_connection(); + match redis.get::<_, String>(&key).await { + Ok(token_data_str) => { + log::debug!("Found token data in Redis: {}", token_data_str); + if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { + captcha_validated = updated_token.claims.captcha_validated; + log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); + + // Update memory cache if found in Redis + if captcha_validated { + let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); + let mut validated_tokens = self.validated_tokens.write().await; + validated_tokens.insert(claims.jti.clone(), expiration); + } + } else { + log::warn!("Failed to parse token data from Redis"); + } + } + Err(e) => { + log::debug!("Redis token lookup failed for JTI {}: {}", claims.jti, e); + } + } + } else { + log::debug!("Redis manager not available"); + } + } + + // Check if captcha was validated (either from JWT or Redis) + if !captcha_validated { + log::debug!("JWT token not validated for captcha"); + return Ok(false); + } + + // Optional: Check Redis blacklist for revoked tokens + if let Ok(redis_manager) = RedisManager::get() { + let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let mut redis = redis_manager.get_connection(); + match redis.exists::<_, bool>(&blacklist_key).await { + Ok(true) => { + log::debug!("JWT token {} is blacklisted", claims.jti); + return Ok(false); + } + Ok(false) => { + // Token not blacklisted, continue 
validation + } + Err(e) => { + log::warn!("Redis blacklist check error for JWT {}: {}", claims.jti, e); + // Continue validation despite Redis error + } + } + } + + Ok(true) + } + Err(e) => { + log::warn!("JWT token validation failed: {}", e); + Ok(false) + } + } + } + + /// Mark a JWT token as validated after successful captcha completion + pub async fn mark_token_validated(&self, token: &str) -> Result<()> { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Store the JTI as validated in memory cache + let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); + { + let mut validated_tokens = self.validated_tokens.write().await; + validated_tokens.insert(claims.jti.clone(), expiration); + log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); + } + + // Also update Redis cache if available (for persistence across restarts) + if let Ok(redis_manager) = RedisManager::get() { + let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + log::debug!("Storing updated token in Redis with key: {}", key); + + let mut redis = redis_manager.get_connection(); + let mut updated_claims = claims.clone(); + updated_claims.captcha_validated = true; + + let updated_captcha_token = CaptchaToken { + token: token.to_string(), + claims: updated_claims, + }; + let token_data = serde_json::to_string(&updated_captcha_token) + .context("Failed to serialize updated captcha token")?; + + log::debug!("Token data to store: {}", token_data); + + let _: () = redis + .set_ex(&key, token_data, self.config.token_ttl_seconds) + .await + .context("Failed to update captcha token in Redis")?; + + log::debug!("Successfully stored updated token in Redis"); + } else { + log::debug!("Redis manager not available for token storage"); + } + + Ok(()) + } + Err(e) => { + log::warn!("Failed to decode JWT token for validation marking: {}", e); + Err(anyhow::anyhow!("Invalid JWT token: {}", e)) + } + } + } + + /// Revoke a JWT token by adding it to blacklist + pub async fn revoke_token(&self, token: &str) -> Result<()> { + let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); + let mut validation = Validation::new(Algorithm::HS256); + validation.set_audience(&["captcha-validation"]); + + match decode::(token, &decoding_key, &validation) { + Ok(token_data) => { + let claims = token_data.claims; + + // Add to Redis blacklist + if let Ok(redis_manager) = RedisManager::get() { + let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let mut redis = redis_manager.get_connection(); + let _: () = redis + .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) + .await + .context("Failed to add token to blacklist")?; + } + + Ok(()) + } + Err(e) => { + log::warn!("Failed to decode JWT token for revocation: {}", e); + Err(anyhow::anyhow!("Invalid JWT token: {}", e)) + } + } + } + + /// Apply captcha challenge (return HTML form) + pub fn apply_captcha_challenge(&self, site_key: &str) -> String { + self.render_captcha_template(site_key, None) + } + + /// Apply captcha challenge with JWT token (return HTML form) + pub fn apply_captcha_challenge_with_token(&self, site_key: &str, jwt_token: &str) -> String 
{
+        self.render_captcha_template(site_key, Some(jwt_token))
+    }
+
+    /// Render captcha template based on provider
+    fn render_captcha_template(&self, site_key: &str, jwt_token: Option<&str>) -> String {
+        let (frontend_js, frontend_key, callback_attr) = match self.config.provider {
+            CaptchaProvider::HCaptcha => (
+                "https://js.hcaptcha.com/1/api.js",
+                "h-captcha",
+                "data-callback=\"captchaCallback\""
+            ),
+            CaptchaProvider::ReCaptcha => (
+                "https://www.recaptcha.net/recaptcha/api.js",
+                "g-recaptcha",
+                "data-callback=\"captchaCallback\""
+            ),
+            CaptchaProvider::Turnstile => (
+                "https://challenges.cloudflare.com/turnstile/v0/api.js",
+                "cf-turnstile",
+                "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\""
+            ),
+        };
+
+        let jwt_token_input = if let Some(token) = jwt_token {
+            // Hidden input carrying the JWT; tag reconstructed, the original markup was lost in extraction
+            format!(r#"<input type="hidden" id="jwt-token" value="{}">"#, token)
+        } else {
+            r#""#.to_string()
+        };
+
+        let html_template = format!(
+            r#"
+            [challenge page markup garbled during extraction; recoverable content: a page
+             titled "Gen0Sec Captcha" with the prompt "Please complete the security
+             verification below to continue.", the provider widget container, and the
+             hidden JWT token input rendered from the format arguments below]
+ + + +"#, + frontend_js, + frontend_key, + site_key, + callback_attr, + jwt_token_input + ); + html_template + } + + /// Validate with hCaptcha API + async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result { + // Use shared HTTP client with keepalive instead of creating new client + let client = get_global_reqwest_client() + .context("Failed to get global HTTP client")?; + + let mut params = HashMap::new(); + params.insert("response", &request.response_token); + params.insert("secret", &request.secret_key); + params.insert("sitekey", &request.site_key); + params.insert("remoteip", &request.ip_address); + + log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), request.ip_address); + + let response = client + .post("https://hcaptcha.com/siteverify") + .form(¶ms) + .send() + .await + .context("Failed to send hCaptcha validation request")?; + + log::info!("hCaptcha validation HTTP response - status: {}", response.status()); + + if !response.status().is_success() { + log::error!("hCaptcha service returned non-success status: {}", response.status()); + return Ok(false); + } + + let validation_response: CaptchaValidationResponse = response + .json() + .await + .context("Failed to parse hCaptcha response")?; + + if !validation_response.success { + if let Some(error_codes) = &validation_response.error_codes { + for error_code in error_codes { + match error_code.as_str() { + "invalid-input-secret" => { + log::error!("hCaptcha secret key is invalid"); + return Ok(false); + } + "invalid-input-response" => { + log::info!("Invalid hCaptcha response from user"); + return Ok(false); + } + "timeout-or-duplicate" => { + log::info!("hCaptcha response expired or duplicate"); + return Ok(false); + } + _ => { + log::warn!("hCaptcha validation failed with error code: {}", error_code); + } + } + } + } + log::info!("hCaptcha validation failed without specific error code"); + return Ok(false); + } + + Ok(true) + } + + /// Validate with reCAPTCHA API + async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result { + // Use shared HTTP client with keepalive instead of creating new client + let client = get_global_reqwest_client() + .context("Failed to get global HTTP client")?; + + let mut params = HashMap::new(); + params.insert("response", &request.response_token); + params.insert("secret", &request.secret_key); + params.insert("remoteip", &request.ip_address); + + log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), request.ip_address); + + let response = client + .post("https://www.recaptcha.net/recaptcha/api/siteverify") + .form(¶ms) + .send() + .await + .context("Failed to send reCAPTCHA validation request")?; + + log::info!("reCAPTCHA validation HTTP response - status: {}", response.status()); + + if !response.status().is_success() { + log::error!("reCAPTCHA service returned non-success status: {}", response.status()); + return Ok(false); + } + + let validation_response: CaptchaValidationResponse = response + .json() + .await + .context("Failed to parse reCAPTCHA response")?; + + if !validation_response.success { + if let Some(error_codes) = &validation_response.error_codes { + for error_code in error_codes { + match error_code.as_str() { + "invalid-input-secret" => { + log::error!("reCAPTCHA secret key is invalid"); + return Ok(false); + } + "invalid-input-response" => { + log::info!("Invalid reCAPTCHA response from user"); + return Ok(false); + } + 
"timeout-or-duplicate" => { + log::info!("reCAPTCHA response expired or duplicate"); + return Ok(false); + } + _ => { + log::warn!("reCAPTCHA validation failed with error code: {}", error_code); + } + } + } + } + log::info!("reCAPTCHA validation failed without specific error code"); + return Ok(false); + } + + Ok(true) + } + + /// Validate with Cloudflare Turnstile API + async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result { + // Use shared HTTP client with keepalive instead of creating new client + let client = get_global_reqwest_client() + .context("Failed to get global HTTP client")?; + + let mut params = HashMap::new(); + params.insert("response", &request.response_token); + params.insert("secret", &request.secret_key); + params.insert("remoteip", &request.ip_address); + + log::info!("Turnstile validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), request.ip_address); + + let response = client + .post("https://challenges.cloudflare.com/turnstile/v0/siteverify") + .form(¶ms) + .send() + .await + .context("Failed to send Turnstile validation request")?; + + log::info!("Turnstile validation HTTP response - status: {}", response.status()); + + if !response.status().is_success() { + log::error!("Turnstile service returned non-success status: {}", response.status()); + return Ok(false); + } + + let validation_response: CaptchaValidationResponse = response + .json() + .await + .context("Failed to parse Turnstile response")?; + + if !validation_response.success { + if let Some(error_codes) = &validation_response.error_codes { + for error_code in error_codes { + match error_code.as_str() { + "invalid-input-secret" => { + log::error!("Turnstile secret key is invalid"); + return Ok(false); + } + "invalid-input-response" => { + log::info!("Invalid Turnstile response from user"); + return Ok(false); + } + "timeout-or-duplicate" => { + log::info!("Turnstile response expired or duplicate"); + return Ok(false); + } + _ => { + log::warn!("Turnstile validation failed with error code: {}", error_code); + } + } + } + } + log::info!("Turnstile validation failed without specific error code"); + return Ok(false); + } + + Ok(true) + } + + /// Get the captcha backend response key name for the current provider + pub fn get_captcha_backend_key(&self) -> &'static str { + match self.config.provider { + CaptchaProvider::HCaptcha => "h-captcha-response", + CaptchaProvider::ReCaptcha => "g-recaptcha-response", + CaptchaProvider::Turnstile => "cf-turnstile-response", + } + } + + /// Get validation result from cache + async fn get_validation_cache(&self, key: &str) -> Option { + let cache = self.validation_cache.read().await; + cache.get(key).cloned() + } + + /// Set validation result in cache + async fn set_validation_cache(&self, key: &str, is_valid: bool) { + let mut cache = self.validation_cache.write().await; + cache.insert( + key.to_string(), + CachedCaptchaResult { + is_valid, + expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds), + }, + ); + } + + /// Remove validation result from cache + async fn remove_validation_cache(&self, key: &str) { + let mut cache = self.validation_cache.write().await; + cache.remove(key); + } + + /// Clean up expired cache entries + pub async fn cleanup_cache(&self) { + let mut cache = self.validation_cache.write().await; + let now = Instant::now(); + cache.retain(|_, cached| cached.expires_at > now); + + // Also clean up expired validated tokens + let mut validated_tokens = 
self.validated_tokens.write().await; + validated_tokens.retain(|_, expiration| *expiration > now); + } +} + +/// Global captcha client instance +static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); + +/// Initialize the global captcha client +pub async fn init_captcha_client( + config: CaptchaConfig, +) -> Result<()> { + let client = Arc::new(CaptchaClient::new(config)); + + CAPTCHA_CLIENT.set(client) + .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; + + Ok(()) +} + +/// Validate captcha response +pub async fn validate_captcha_response( + response_token: String, + ip_address: String, + user_agent: Option, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + let request = CaptchaValidationRequest { + response_token, + ip_address, + user_agent: user_agent, + site_key: client.config.site_key.clone(), + secret_key: client.config.secret_key.clone(), + provider: client.config.provider.clone(), + }; + + client.validate_captcha(request).await +} + +/// Generate captcha token +pub async fn generate_captcha_token( + ip_address: String, + user_agent: String, + ja4_fingerprint: Option, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.generate_token(ip_address, user_agent, ja4_fingerprint).await +} + +/// Validate captcha token +pub async fn validate_captcha_token( + token: &str, + ip_address: &str, + user_agent: &str, +) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.validate_token(token, ip_address, user_agent).await +} + +/// Apply captcha challenge +pub fn apply_captcha_challenge() -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.apply_captcha_challenge(&client.config.site_key)) +} + +/// Apply captcha challenge with JWT token +pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token)) +} + +/// Get the captcha backend response key name +pub fn get_captcha_backend_key() -> Result<&'static str> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + Ok(client.get_captcha_backend_key()) +} + +/// Mark a JWT token as validated after successful captcha completion +pub async fn mark_captcha_token_validated(token: &str) -> Result<()> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.mark_token_validated(token).await +} + +/// Revoke a JWT token +pub async fn revoke_captcha_token(token: &str) -> Result<()> { + let client = CAPTCHA_CLIENT + .get() + .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; + + client.revoke_token(token).await +} + +/// Validate captcha response and mark token as validated +pub async fn validate_and_mark_captcha( + response_token: String, + jwt_token: String, + ip_address: String, + user_agent: Option, +) -> Result { + log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", + ip_address, response_token.len(), jwt_token.len()); + + // First validate the captcha response + let is_valid = validate_captcha_response(response_token, 
ip_address.clone(), user_agent.clone()).await?; + + log::info!("Captcha validation result: {}", is_valid); + + if is_valid { + // Only try to mark JWT token as validated if it's not empty + if !jwt_token.is_empty() { + if let Err(e) = mark_captcha_token_validated(&jwt_token).await { + log::warn!("Failed to mark JWT token as validated: {}", e); + // Don't return false here - captcha validation succeeded + } else { + log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address); + } + } else { + log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address); + } + } else { + log::warn!("Captcha validation failed for IP: {}", ip_address); + } + + Ok(is_valid) +} + +/// Start periodic cache cleanup task +pub async fn start_cache_cleanup_task() { + tokio::spawn(async { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + if let Some(client) = CAPTCHA_CLIENT.get() { + client.cleanup_cache().await; + } + } + }); +} diff --git a/src/waf/actions/mod.rs b/src/waf/actions/mod.rs index fabe28f..53404ed 100644 --- a/src/waf/actions/mod.rs +++ b/src/waf/actions/mod.rs @@ -1,2 +1,2 @@ -pub mod captcha; -pub mod block; +pub mod captcha; +pub mod block; diff --git a/src/waf/mod.rs b/src/waf/mod.rs index 056cfb0..f3b0114 100644 --- a/src/waf/mod.rs +++ b/src/waf/mod.rs @@ -1,2 +1,2 @@ -pub mod wirefilter; -pub mod actions; +pub mod wirefilter; +pub mod actions; diff --git a/src/waf/wirefilter.rs b/src/waf/wirefilter.rs index edd8ad0..912efad 100644 --- a/src/waf/wirefilter.rs +++ b/src/waf/wirefilter.rs @@ -1,1110 +1,1110 @@ -use std::collections::HashSet; -use std::net::SocketAddr; -use std::sync::{Arc, RwLock, OnceLock}; - -use anyhow::Result; -use sha2::{Digest, Sha256}; -use wirefilter::{ExecutionContext, Scheme, TypedArray, TypedMap}; -use crate::worker::config::{Config, fetch_config}; -use crate::threat; -use anyhow::anyhow; - -/// String interner for WAF rule expressions -/// This deduplicates expression strings to minimize memory usage. -/// Each unique expression is stored once and reused across rule updates. -/// Note: Interned strings are leaked to satisfy wirefilter's 'static lifetime requirement, -/// but deduplication ensures each unique expression is only leaked once. 
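The interner below satisfies wirefilter's &'static str requirement by leaking each expression, but because it tracks only hashes it still leaks a fresh allocation when the same expression is interned again, as its own comments concede. A minimal sketch of a variant that stores the leaked pointer per expression, so each unique expression is leaked exactly once (hypothetical code, not part of this patch):

use std::collections::HashMap;

struct DedupInterner {
    // Maps each expression to the one &'static str ever leaked for it.
    interned: HashMap<String, &'static str>,
}

impl DedupInterner {
    fn intern(&mut self, expr: &str) -> &'static str {
        if let Some(&s) = self.interned.get(expr) {
            return s; // reuse the existing allocation instead of leaking again
        }
        let leaked: &'static str = Box::leak(expr.to_string().into_boxed_str());
        self.interned.insert(expr.to_string(), leaked);
        leaked
    }
}

The trade-off is keeping an owned copy of each expression as the map key, which is usually negligible next to the compiled filters themselves.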
-struct ExpressionInterner { - // Set of already interned expression hashes to avoid re-leaking - interned: HashSet, -} - -impl ExpressionInterner { - fn new() -> Self { - Self { - interned: HashSet::new(), - } - } - - /// Intern an expression string, returning a &'static str reference - /// If the expression was already interned, recomputes the hash but doesn't re-leak - /// (The hash check prevents re-leaking the same expression) - fn intern(&mut self, expr: &str) -> &'static str { - use std::hash::{Hash, Hasher}; - use std::collections::hash_map::DefaultHasher; - - let mut hasher = DefaultHasher::new(); - expr.hash(&mut hasher); - let hash = hasher.finish(); - - if self.interned.contains(&hash) { - // Already interned, but we still need to return a 'static reference - // We must leak again since we don't store the actual pointers - // This is a trade-off: we track hashes to log warnings, but still leak - log::trace!("Re-using previously interned expression (hash collision or same expr)"); - } - - self.interned.insert(hash); - Box::leak(expr.to_string().into_boxed_str()) - } - - /// Get the count of unique expressions interned - fn count(&self) -> usize { - self.interned.len() - } -} - -/// Global expression interner to deduplicate WAF rule expressions across updates -static EXPRESSION_INTERNER: OnceLock> = OnceLock::new(); - -fn get_expression_interner() -> &'static RwLock { - EXPRESSION_INTERNER.get_or_init(|| RwLock::new(ExpressionInterner::new())) -} - -/// Intern a WAF rule expression, deduplicating across all rules -fn intern_expression(expr: &str) -> &'static str { - let mut interner = get_expression_interner().write().unwrap(); - interner.intern(expr) -} - -/// Get the count of interned expressions (for diagnostics) -pub fn get_interned_expression_count() -> usize { - if let Some(interner) = EXPRESSION_INTERNER.get() { - interner.read().unwrap().count() - } else { - 0 - } -} - -/// WAF action types -#[derive(Debug, Clone, PartialEq)] -pub enum WafAction { - Block, - Challenge, - RateLimit, - Allow, -} - -impl WafAction { - pub fn from_str(action: &str) -> Self { - match action.to_lowercase().as_str() { - "block" => WafAction::Block, - "challenge" => WafAction::Challenge, - "ratelimit" => WafAction::RateLimit, - _ => WafAction::Allow, - } - } -} - -/// WAF rule evaluation result -#[derive(Debug, Clone)] -pub struct WafResult { - pub action: WafAction, - pub rule_name: String, - pub rule_id: String, - pub rate_limit_config: Option, - pub threat_response: Option, -} - -/// Wirefilter-based HTTP request filtering engine -pub struct HttpFilter { - scheme: Arc, - rules: Arc)>>>, // (filter, action, name, id, rate_limit_config) - rules_hash: Arc>>, -} - -impl HttpFilter { - /// Create the wirefilter scheme with HTTP request fields - fn create_scheme() -> Scheme { - let mut builder = Scheme! 
{ - http.request.method: Bytes, - http.request.scheme: Bytes, - http.request.host: Bytes, - http.request.port: Int, - http.request.path: Bytes, - http.request.uri: Bytes, - http.request.query: Bytes, - http.request.user_agent: Bytes, - http.request.content_type: Bytes, - http.request.content_length: Int, - http.request.body: Bytes, - http.request.body_sha256: Bytes, - http.request.headers: Map(Array(Bytes)), - ip.src: Ip, - ip.src.country: Bytes, - ip.src.asn: Int, - ip.src.asn_org: Bytes, - ip.src.asn_country: Bytes, - threat.score: Int, - threat.advice: Bytes, - signal.ja4: Bytes, - signal.ja4_raw: Bytes, - signal.ja4_unsorted: Bytes, - signal.ja4_raw_unsorted: Bytes, - signal.tls_version: Bytes, - signal.cipher_suite: Bytes, - signal.sni: Bytes, - signal.alpn: Bytes, - signal.ja4h: Bytes, - signal.ja4h_method: Bytes, - signal.ja4h_version: Bytes, - signal.ja4h_has_cookie: Int, - signal.ja4h_has_referer: Int, - signal.ja4h_header_count: Int, - signal.ja4h_language: Bytes, - signal.ja4t: Bytes, - signal.ja4t_window_size: Int, - signal.ja4t_ttl: Int, - signal.ja4t_mss: Int, - signal.ja4t_window_scale: Int, - signal.ja4l_client: Bytes, - signal.ja4l_server: Bytes, - signal.ja4l_syn_time: Int, - signal.ja4l_synack_time: Int, - signal.ja4l_ack_time: Int, - signal.ja4l_ttl_client: Int, - signal.ja4l_ttl_server: Int, - signal.ja4s: Bytes, - signal.ja4s_proto: Bytes, - signal.ja4s_version: Bytes, - signal.ja4s_cipher: Int, - signal.ja4s_alpn: Bytes, - signal.ja4x: Bytes, - signal.ja4x_issuer_rdns: Bytes, - signal.ja4x_subject_rdns: Bytes, - signal.ja4x_extensions: Bytes, - }; - - // Register functions used in Cloudflare-style expressions - builder.add_function("any", wirefilter::AnyFunction::default()).unwrap(); - builder.add_function("all", wirefilter::AllFunction::default()).unwrap(); - - builder.add_function("cidr", wirefilter::CIDRFunction::default()).unwrap(); - builder.add_function("concat", wirefilter::ConcatFunction::default()).unwrap(); - builder.add_function("decode_base64", wirefilter::DecodeBase64Function::default()).unwrap(); - builder.add_function("ends_with", wirefilter::EndsWithFunction::default()).unwrap(); - builder.add_function("json_lookup_integer", wirefilter::JsonLookupIntegerFunction::default()).unwrap(); - builder.add_function("json_lookup_string", wirefilter::JsonLookupStringFunction::default()).unwrap(); - builder.add_function("len", wirefilter::LenFunction::default()).unwrap(); - builder.add_function("lower", wirefilter::LowerFunction::default()).unwrap(); - builder.add_function("remove_bytes", wirefilter::RemoveBytesFunction::default()).unwrap(); - builder.add_function("remove_query_args", wirefilter::RemoveQueryArgsFunction::default()).unwrap(); - builder.add_function("starts_with", wirefilter::StartsWithFunction::default()).unwrap(); - builder.add_function("substring", wirefilter::SubstringFunction::default()).unwrap(); - builder.add_function("to_string", wirefilter::ToStringFunction::default()).unwrap(); - builder.add_function("upper", wirefilter::UpperFunction::default()).unwrap(); - builder.add_function("url_decode", wirefilter::UrlDecodeFunction::default()).unwrap(); - builder.add_function("uuid4", wirefilter::UUID4Function::default()).unwrap(); - builder.add_function("wildcard_replace", wirefilter::WildcardReplaceFunction::default()).unwrap(); - - - builder.build() - } - - /// Create a new HTTP filter with the given filter expression (static version) - pub fn new(filter_expr: &'static str) -> Result { - // Create the scheme with HTTP request fields - let scheme 
= Arc::new(Self::create_scheme()); - - // Parse the filter expression - let ast = scheme.parse(filter_expr)?; - - // Compile the filter - let filter = ast.compile(); - - Ok(Self { - scheme, - rules: Arc::new(RwLock::new(vec![ - (filter, WafAction::Block, "default".to_string(), "default".to_string(), None) - ])), - rules_hash: Arc::new(RwLock::new(None)), - }) - } - - /// Create a new HTTP filter from config WAF rules - pub fn new_from_config(config: &Config) -> Result { - // Create the scheme with HTTP request fields - let scheme = Arc::new(Self::create_scheme()); - - if config.waf_rules.rules.is_empty() { - // If no WAF rules, create a default filter that allows all - return Ok(Self { - scheme, - rules: Arc::new(RwLock::new(vec![])), - rules_hash: Arc::new(RwLock::new(Some(Self::compute_rules_hash("")))), - }); - } - - // Validate and compile individual WAF rules - let mut compiled_rules = Vec::new(); - let mut rules_hash_input = String::new(); - - for rule in &config.waf_rules.rules { - // Basic validation - check if expression is not empty - if rule.expression.trim().is_empty() { - log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name); - continue; - } - - // Try to parse the expression to validate it - if let Err(error) = scheme.parse(&rule.expression) { - log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error); - continue; - } - - // Compile the rule using interned expression to minimize memory leakage - let expression = intern_expression(&rule.expression); - let ast = scheme.parse(expression)?; - let filter = ast.compile(); - let action = WafAction::from_str(&rule.action); - - // Parse rate limit config if action is RateLimit - let rate_limit_config = if action == WafAction::RateLimit { - rule.config.as_ref().and_then(|cfg| { - match crate::worker::config::RateLimitConfig::from_json(cfg) { - Ok(config) => { - log::debug!("Parsed rate limit config for rule {}: period={}, requests={}", - rule.id, config.period, config.requests); - Some(config) - } - Err(e) => { - log::error!("Failed to parse rate limit config for rule {}: {}. 
Config JSON: {}", - rule.id, e, serde_json::to_string(cfg).unwrap_or_else(|_| "invalid json".to_string())); - None - } - } - }) - } else { - None - }; - - compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); - rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); - } - - if compiled_rules.is_empty() { - log::warn!("No valid WAF rules found, using default filter that allows all"); - return Ok(Self { - scheme, - rules: Arc::new(RwLock::new(vec![])), - rules_hash: Arc::new(RwLock::new(Some(Self::compute_rules_hash("")))), - }); - } - - let hash = Self::compute_rules_hash(&rules_hash_input); - log::debug!("WAF expression interner now has {} unique expressions", get_interned_expression_count()); - Ok(Self { - scheme, - rules: Arc::new(RwLock::new(compiled_rules)), - rules_hash: Arc::new(RwLock::new(Some(hash))), - }) - } - - /// Update the filter with new WAF rules from config - pub fn update_from_config(&self, config: &Config) -> Result<()> { - // Validate and compile individual WAF rules - let mut compiled_rules = Vec::new(); - let mut rules_hash_input = String::new(); - - for rule in &config.waf_rules.rules { - // Basic validation - check if expression is not empty - if rule.expression.trim().is_empty() { - log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name); - continue; - } - - // Try to parse the expression to validate it - if let Err(error) = self.scheme.parse(&rule.expression) { - log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error); - continue; - } - - // Compile the rule using interned expression to minimize memory leakage - let expression = intern_expression(&rule.expression); - let ast = self.scheme.parse(expression)?; - let filter = ast.compile(); - let action = WafAction::from_str(&rule.action); - - // Parse rate limit config if action is RateLimit - let rate_limit_config = if action == WafAction::RateLimit { - rule.config.as_ref().and_then(|cfg| { - match crate::worker::config::RateLimitConfig::from_json(cfg) { - Ok(config) => { - log::debug!("Parsed rate limit config for rule {}: period={}, requests={}", - rule.id, config.period, config.requests); - Some(config) - } - Err(e) => { - log::error!("Failed to parse rate limit config for rule {}: {}. 
Config JSON: {}", - rule.id, e, serde_json::to_string(cfg).unwrap_or_else(|_| "invalid json".to_string())); - None - } - } - }) - } else { - None - }; - - compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); - rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); - } - - // Compute hash and skip update if unchanged - let new_hash = Self::compute_rules_hash(&rules_hash_input); - if let Some(prev) = self.rules_hash.read().unwrap().as_ref() { - if prev == &new_hash { - log::debug!("HTTP filter WAF rules unchanged; skipping update"); - return Ok(()); - } - } - - let rules_count = compiled_rules.len(); - *self.rules.write().unwrap() = compiled_rules; - *self.rules_hash.write().unwrap() = Some(new_hash); - - log::info!("HTTP filter updated with {} WAF rules from config (expression interner: {} unique expressions)", - rules_count, get_interned_expression_count()); - - Ok(()) - } - - fn compute_rules_hash(expr: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(expr.as_bytes()); - hex::encode(hasher.finalize()) - } - - /// Get the current filter expression (for debugging) - pub fn get_current_expression(&self) -> String { - // This is a simplified version - in practice you might want to store the original expression - "dynamic_filter_from_config".to_string() - } - - /// Check if the given HTTP request should be blocked using request parts and body bytes - pub async fn should_block_request_from_parts( - &self, - req_parts: &hyper::http::request::Parts, - body_bytes: &[u8], - peer_addr: SocketAddr, - ) -> Result> { - // Create execution context - let mut ctx = ExecutionContext::new(&self.scheme); - - // Extract request information - let method = req_parts.method.as_str(); - let uri = &req_parts.uri; - let scheme = uri.scheme().map(|s| s.as_str()).unwrap_or("http"); - let host = uri.host().unwrap_or("").to_string(); - let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 }); - let path = uri.path().to_string(); - let full_uri = uri.to_string(); - let query = uri.query().unwrap_or("").to_string(); - - // Extract headers - let user_agent = req_parts - .headers - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - let content_type = req_parts - .headers - .get("content-type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - // Get content length - let content_length = req_parts - .headers - .get("content-length") - .and_then(|h| h.to_str().ok()) - .and_then(|s| s.parse::().ok()) - .unwrap_or(body_bytes.len() as i64); - - // Process request body - let body_text = String::from_utf8_lossy(body_bytes).to_string(); - - // Calculate body SHA256 - let mut hasher = Sha256::new(); - hasher.update(body_bytes); - let body_sha256_hex = hex::encode(hasher.finalize()); - - // Set field values in execution context - ctx.set_field_value( - self.scheme.get_field("http.request.method").unwrap(), - method, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.scheme").unwrap(), - scheme, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.host").unwrap(), - host, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.port").unwrap(), - port as i64, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.path").unwrap(), - path, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.uri").unwrap(), - full_uri, - )?; - ctx.set_field_value( - 
self.scheme.get_field("http.request.query").unwrap(), - query, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.user_agent").unwrap(), - user_agent, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.content_type").unwrap(), - content_type, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.headers").unwrap(), - { - let mut headers_map: TypedMap<'_, TypedArray<'_, &[u8]>> = TypedMap::new(); - for (name, value) in req_parts.headers.iter() { - let key = name.as_str().to_ascii_lowercase().into_bytes().into_boxed_slice(); - let entry = headers_map.get_or_insert(key, TypedArray::new()); - match value.to_str() { - Ok(s) => entry.push(s.as_bytes()), - Err(_) => entry.push(value.as_bytes()), - } - } - headers_map - }, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.content_length").unwrap(), - content_length, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.body").unwrap(), - body_text, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.body_sha256").unwrap(), - body_sha256_hex, - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src").unwrap(), - peer_addr.ip(), - )?; - - // Fetch threat intelligence data for the source IP - // Fetch full threat response for access logging, and WAF fields for rule evaluation - log::info!("🔍 [WAF] Looking up GeoIP for: {}", peer_addr.ip()); - let threat_response = threat::get_threat_intel(&peer_addr.ip().to_string()).await.ok().flatten(); - let _threat_fields = if let Some(ref threat_resp) = threat_response { - let waf_fields = threat::WafFields::from(threat_resp); - log::info!("🔍 [WAF] GeoIP Result: Country='{}', ASN={}, ThreatScore={}", - waf_fields.ip_src_country, waf_fields.ip_src_asn, waf_fields.threat_score); - // Set threat intelligence fields - ctx.set_field_value( - self.scheme.get_field("ip.src.country").unwrap(), - waf_fields.ip_src_country.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn").unwrap(), - waf_fields.ip_src_asn as i64, - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_org").unwrap(), - waf_fields.ip_src_asn_org.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_country").unwrap(), - waf_fields.ip_src_asn_country.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("threat.score").unwrap(), - waf_fields.threat_score as i64, - )?; - ctx.set_field_value( - self.scheme.get_field("threat.advice").unwrap(), - waf_fields.threat_advice.clone(), - )?; - Some(waf_fields) - } else { - // No threat data found, set default values - log::warn!("🔍 [WAF] No GeoIP data found for {}, setting empty country", peer_addr.ip()); - ctx.set_field_value( - self.scheme.get_field("ip.src.country").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_org").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_country").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("threat.score").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("threat.advice").unwrap(), - "", - )?; - None - }; - - // Extract HTTP version - let http_version = format!("{:?}", req_parts.version); - - // Generate JA4H fingerprint from HTTP request (available now) - let ja4h_fp = crate::ja4_plus::Ja4hFingerprint::from_http_request( - method, - &http_version, - &req_parts.headers, - ); - - // Set default empty values for all signal (JA4) fields - // 
These fields will be populated when JA4 data is available - ctx.set_field_value( - self.scheme.get_field("signal.ja4").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4_raw").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4_unsorted").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4_raw_unsorted").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.tls_version").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.cipher_suite").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.sni").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.alpn").unwrap(), - "", - )?; - // Populate JA4H fields from generated fingerprint - ctx.set_field_value( - self.scheme.get_field("signal.ja4h").unwrap(), - ja4h_fp.fingerprint.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_method").unwrap(), - ja4h_fp.method.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_version").unwrap(), - ja4h_fp.version.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_has_cookie").unwrap(), - if ja4h_fp.has_cookie { 1i64 } else { 0i64 }, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_has_referer").unwrap(), - if ja4h_fp.has_referer { 1i64 } else { 0i64 }, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_header_count").unwrap(), - ja4h_fp.header_count as i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4h_language").unwrap(), - ja4h_fp.language.clone(), - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_window_size").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_ttl").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_mss").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_window_scale").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_client").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_server").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_syn_time").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_synack_time").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_ack_time").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_ttl_client").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_ttl_server").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_proto").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_version").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_cipher").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_alpn").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4x").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4x_issuer_rdns").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4x_subject_rdns").unwrap(), - "", - )?; - ctx.set_field_value( - 
self.scheme.get_field("signal.ja4x_extensions").unwrap(), - "", - )?; - - // Execute each rule individually and return the first match - let rules_guard = self.rules.read().unwrap(); - for (filter, action, rule_name, rule_id, rate_limit_config) in rules_guard.iter() { - let rule_result = filter.execute(&ctx)?; - if rule_result { - return Ok(Some(WafResult { - action: action.clone(), - rule_name: rule_name.clone(), - rule_id: rule_id.clone(), - rate_limit_config: rate_limit_config.clone(), - threat_response: threat_response.clone(), - })); - } - } - - Ok(None) - } -} - -// Global wirefilter instance for HTTP request filtering -static HTTP_FILTER: OnceLock = OnceLock::new(); - -pub fn get_global_http_filter() -> Option<&'static HttpFilter> { - HTTP_FILTER.get() -} - -pub fn set_global_http_filter(filter: HttpFilter) -> anyhow::Result<()> { - HTTP_FILTER - .set(filter) - .map_err(|_| anyhow!("Failed to initialize HTTP filter")) -} - - -/// Initialize the global config + HTTP filter from API with retry logic -pub async fn init_config(base_url: String, api_key: String) -> anyhow::Result<()> { - let mut retry_count = 0; - const MAX_RETRIES: u32 = 3; - const RETRY_DELAY_MS: u64 = 1000; - - loop { - match fetch_config(base_url.clone(), api_key.clone()).await { - Ok(config_response) => { - let filter = HttpFilter::new_from_config(&config_response.config)?; - set_global_http_filter(filter)?; - log::info!("HTTP filter initialized with {} WAF rules from config", config_response.config.waf_rules.rules.len()); - return Ok(()); - } - Err(e) => { - let error_msg = e.to_string(); - if error_msg.contains("503") && retry_count < MAX_RETRIES { - retry_count += 1; - log::warn!("Failed to fetch config for HTTP filter (attempt {}): {}. Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); - tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; - continue; - } else { - log::error!("Failed to fetch config for HTTP filter after {} attempts: {}", retry_count + 1, error_msg); - return Err(anyhow!("Failed to initialize HTTP filter: {}", error_msg)); - } - } - } - } -} - -/// Update the global HTTP filter with new config with retry logic -pub async fn update_with_config(base_url: String, api_key: String) -> anyhow::Result<()> { - let mut retry_count = 0; - const MAX_RETRIES: u32 = 3; - const RETRY_DELAY_MS: u64 = 1000; - - loop { - match fetch_config(base_url.clone(), api_key.clone()).await { - Ok(config_response) => { - if let Some(filter) = HTTP_FILTER.get() { - filter.update_from_config(&config_response.config)?; - } else { - log::warn!("HTTP filter not initialized, cannot update"); - } - return Ok(()); - } - Err(e) => { - let error_msg = e.to_string(); - if error_msg.contains("503") && retry_count < MAX_RETRIES { - retry_count += 1; - log::warn!("Failed to fetch config for HTTP filter update (attempt {}): {}. 
Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); - tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; - continue; - } else { - log::error!("Failed to fetch config for HTTP filter update after {} attempts: {}", retry_count + 1, error_msg); - return Err(anyhow!("Failed to fetch config: {}", error_msg)); - } - } - } - } -} - -/// Update the global HTTP filter using an already-fetched Config value -pub fn update_http_filter_from_config_value(config: &Config) -> anyhow::Result<()> { - if let Some(filter) = HTTP_FILTER.get() { - filter.update_from_config(config)?; - Ok(()) - } else { - log::warn!("HTTP filter not initialized, cannot update"); - Ok(()) - } -} - -/// Load WAF rules from a Vec of WafRule (for local mode) -pub async fn load_waf_rules(waf_rules: Vec) -> anyhow::Result<()> { - // Create a minimal config with just WAF rules - let config = crate::worker::config::Config { - access_rules: crate::worker::config::AccessRule { - id: "local".to_string(), - name: "Local Rules".to_string(), - description: "Local security rules".to_string(), - allow: crate::worker::config::RuleSet { - asn: vec![], - country: vec![], - ips: vec![], - }, - block: crate::worker::config::RuleSet { - asn: vec![], - country: vec![], - ips: vec![], - }, - }, - waf_rules: crate::worker::config::WafRules { rules: waf_rules }, - content_scanning: crate::content_scanning::ContentScanningConfig::default(), - created_at: chrono::Utc::now().to_rfc3339(), - updated_at: chrono::Utc::now().to_rfc3339(), - last_modified: chrono::Utc::now().to_rfc3339(), - }; - - let filter = HttpFilter::new_from_config(&config)?; - set_global_http_filter(filter)?; - Ok(()) -} - -/// Evaluate WAF rules for a Pingora request -/// This is a convenience function that converts Pingora's RequestHeader to hyper's Parts -#[cfg(feature = "proxy")] -pub async fn evaluate_waf_for_pingora_request( - req_header: &pingora_http::RequestHeader, - body_bytes: &[u8], - peer_addr: SocketAddr, -) -> Result> { - let filter = match get_global_http_filter() { - Some(f) => { - // Check if filter has any rules - let rules_count = f.rules.read().unwrap().len(); - if rules_count == 0 { - log::debug!("WAF filter initialized but has no rules loaded"); - } else { - log::debug!("WAF filter has {} rules loaded", rules_count); - } - f - } - None => { - log::debug!("WAF filter not initialized, skipping evaluation"); - return Ok(None); - } - }; - - // Convert Pingora RequestHeader to hyper::http::request::Parts - // Pingora URIs might be relative, so we need to construct a full URI - let uri_str = if req_header.uri.scheme().is_some() { - // Already an absolute URI - req_header.uri.to_string() - } else { - // Construct absolute URI from relative path - // Use http://localhost as base since we only need the path/query for WAF evaluation - format!("http://localhost{}", req_header.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")) - }; - - let uri = match uri_str.parse::() { - Ok(u) => u, - Err(e) => { - log::error!("WAF: Failed to parse URI '{}': {}", uri_str, e); - return Err(anyhow!("Failed to parse URI: {}", e)); - } - }; - - let mut builder = hyper::http::request::Builder::new() - .method(req_header.method.as_str()) - .uri(uri); - - // Copy headers - for (name, value) in req_header.headers.iter() { - if let Ok(name_str) = name.as_str().parse::() { - if let Ok(value_str) = value.to_str() { - builder = builder.header(name_str, value_str); - } else { - builder = builder.header(name_str, value.as_bytes()); - } - } else { - 
log::debug!("WAF: Failed to parse header name: {}", name.as_str()); - } - } - - let req = match builder.body(()) { - Ok(r) => r, - Err(e) => { - log::error!("WAF: Failed to build hyper request: {}", e); - return Err(anyhow!("Failed to build hyper request: {}", e)); - } - }; - let (req_parts, _) = req.into_parts(); - - log::debug!("WAF: Evaluating request - method={}, uri={}, peer={}", - req_header.method.as_str(), uri_str, peer_addr); - - match filter.should_block_request_from_parts(&req_parts, body_bytes, peer_addr).await { - Ok(result) => { - if result.is_some() { - log::debug!("WAF: Rule matched - {:?}", result); - } - Ok(result) - } - Err(e) => { - log::error!("WAF: Evaluation error: {}", e); - Err(e) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hyper::http::request::Builder; - use std::net::Ipv4Addr; - - - #[tokio::test] - async fn test_custom_filter() -> Result<()> { - // Test a custom filter that blocks requests to specific host - let filter = HttpFilter::new("http.request.host == \"blocked.example.com\"")?; - - let req = Builder::new() - .method("GET") - .uri("http://blocked.example.com/test") - .body(())?; - let (req_parts, _) = req.into_parts(); - - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - let result = filter.should_block_request_from_parts(&req_parts, b"", peer_addr).await?; - if let Some(waf_result) = result { - assert_eq!(waf_result.action, WafAction::Block, "Request to blocked host should be blocked"); - } else { - panic!("Request to blocked host should be blocked"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_content_scanning_integration() -> Result<()> { - // Test content scanning integration with wirefilter - let filter = HttpFilter::new("http.request.host == \"example.com\"")?; - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .header("content-type", "text/html") - .body(())?; - let (req_parts, _) = req.into_parts(); - - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - // Test with clean content (should not be blocked by content scanning) - let clean_content = b"Clean content"; - let result = filter.should_block_request_from_parts(&req_parts, clean_content, peer_addr).await?; - - // Should be blocked by host rule, not content scanning - if let Some(waf_result) = result { - assert_eq!(waf_result.rule_name, "default", "Request to example.com should be blocked by host rule"); - } else { - panic!("Request to example.com should be blocked by host rule"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_ja4h_http_version_extraction() -> Result<()> { - // Test that HTTP version is correctly extracted and used in JA4H fingerprint - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - // Test HTTP/1.0 - let filter_http10 = HttpFilter::new("signal.ja4h_version == \"HTTP/1.0\"")?; - let req_http10 = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_10) - .body(())?; - let (req_parts_http10, _) = req_http10.into_parts(); - let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; - if let Some(waf_result) = result_http10 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.0 request should match version check"); - } else { - panic!("HTTP/1.0 request should match version check"); - } - - // Test HTTP/1.1 - let filter_http11 = HttpFilter::new("signal.ja4h_version == \"HTTP/1.1\"")?; - let req_http11 = Builder::new() - 
.method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_11) - .body(())?; - let (req_parts_http11, _) = req_http11.into_parts(); - let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; - if let Some(waf_result) = result_http11 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.1 request should match version check"); - } else { - panic!("HTTP/1.1 request should match version check"); - } - - // Test HTTP/2.0 - let filter_http2 = HttpFilter::new("signal.ja4h_version == \"HTTP/2.0\"")?; - let req_http2 = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_2) - .body(())?; - let (req_parts_http2, _) = req_http2.into_parts(); - let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; - if let Some(waf_result) = result_http2 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/2.0 request should match version check"); - } else { - panic!("HTTP/2.0 request should match version check"); - } - - // Test that wrong version doesn't match - let filter_wrong_version = HttpFilter::new("signal.ja4h_version == \"HTTP/1.0\"")?; - let req_wrong = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_11) - .body(())?; - let (req_parts_wrong, _) = req_wrong.into_parts(); - let result_wrong = filter_wrong_version.should_block_request_from_parts(&req_parts_wrong, b"", peer_addr).await?; - assert!(result_wrong.is_none(), "HTTP/1.1 request should not match HTTP/1.0 version check"); - - Ok(()) - } - - #[tokio::test] - async fn test_ja4h_fingerprint_with_different_versions() -> Result<()> { - // Test that JA4H fingerprint is correctly generated with different HTTP versions - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - // Test that JA4H fingerprint starts with correct version code for HTTP/1.0 (should be "10") - let filter_http10 = HttpFilter::new("starts_with(signal.ja4h, \"ge10\")")?; - let req_http10 = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_10) - .body(())?; - let (req_parts_http10, _) = req_http10.into_parts(); - let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; - assert!(result_http10.is_some(), "HTTP/1.0 request should generate JA4H starting with 'ge10'"); - - // Test that JA4H fingerprint starts with correct version code for HTTP/1.1 (should be "11") - let filter_http11 = HttpFilter::new("starts_with(signal.ja4h, \"ge11\")")?; - let req_http11 = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_11) - .body(())?; - let (req_parts_http11, _) = req_http11.into_parts(); - let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; - assert!(result_http11.is_some(), "HTTP/1.1 request should generate JA4H starting with 'ge11'"); - - // Test that JA4H fingerprint starts with correct version code for HTTP/2.0 (should be "20") - let filter_http2 = HttpFilter::new("starts_with(signal.ja4h, \"ge20\")")?; - let req_http2 = Builder::new() - .method("GET") - .uri("http://example.com/test") - .version(hyper::http::Version::HTTP_2) - .body(())?; - let (req_parts_http2, _) = req_http2.into_parts(); - let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; - 
assert!(result_http2.is_some(), "HTTP/2.0 request should generate JA4H starting with 'ge20'"); - - Ok(()) - } -} +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::{Arc, RwLock, OnceLock}; + +use anyhow::Result; +use sha2::{Digest, Sha256}; +use wirefilter::{ExecutionContext, Scheme, TypedArray, TypedMap}; +use crate::worker::config::{Config, fetch_config}; +use crate::threat; +use anyhow::anyhow; + +/// String interner for WAF rule expressions +/// This deduplicates expression strings to minimize memory usage. +/// Each unique expression is stored once and reused across rule updates. +/// Note: Interned strings are leaked to satisfy wirefilter's 'static lifetime requirement, +/// but deduplication ensures each unique expression is only leaked once. +struct ExpressionInterner { + // Set of already interned expression hashes to avoid re-leaking + interned: HashSet, +} + +impl ExpressionInterner { + fn new() -> Self { + Self { + interned: HashSet::new(), + } + } + + /// Intern an expression string, returning a &'static str reference + /// If the expression was already interned, recomputes the hash but doesn't re-leak + /// (The hash check prevents re-leaking the same expression) + fn intern(&mut self, expr: &str) -> &'static str { + use std::hash::{Hash, Hasher}; + use std::collections::hash_map::DefaultHasher; + + let mut hasher = DefaultHasher::new(); + expr.hash(&mut hasher); + let hash = hasher.finish(); + + if self.interned.contains(&hash) { + // Already interned, but we still need to return a 'static reference + // We must leak again since we don't store the actual pointers + // This is a trade-off: we track hashes to log warnings, but still leak + log::trace!("Re-using previously interned expression (hash collision or same expr)"); + } + + self.interned.insert(hash); + Box::leak(expr.to_string().into_boxed_str()) + } + + /// Get the count of unique expressions interned + fn count(&self) -> usize { + self.interned.len() + } +} + +/// Global expression interner to deduplicate WAF rule expressions across updates +static EXPRESSION_INTERNER: OnceLock> = OnceLock::new(); + +fn get_expression_interner() -> &'static RwLock { + EXPRESSION_INTERNER.get_or_init(|| RwLock::new(ExpressionInterner::new())) +} + +/// Intern a WAF rule expression, deduplicating across all rules +fn intern_expression(expr: &str) -> &'static str { + let mut interner = get_expression_interner().write().unwrap(); + interner.intern(expr) +} + +/// Get the count of interned expressions (for diagnostics) +pub fn get_interned_expression_count() -> usize { + if let Some(interner) = EXPRESSION_INTERNER.get() { + interner.read().unwrap().count() + } else { + 0 + } +} + +/// WAF action types +#[derive(Debug, Clone, PartialEq)] +pub enum WafAction { + Block, + Challenge, + RateLimit, + Allow, +} + +impl WafAction { + pub fn from_str(action: &str) -> Self { + match action.to_lowercase().as_str() { + "block" => WafAction::Block, + "challenge" => WafAction::Challenge, + "ratelimit" => WafAction::RateLimit, + _ => WafAction::Allow, + } + } +} + +/// WAF rule evaluation result +#[derive(Debug, Clone)] +pub struct WafResult { + pub action: WafAction, + pub rule_name: String, + pub rule_id: String, + pub rate_limit_config: Option, + pub threat_response: Option, +} + +/// Wirefilter-based HTTP request filtering engine +pub struct HttpFilter { + scheme: Arc, + rules: Arc)>>>, // (filter, action, name, id, rate_limit_config) + rules_hash: Arc>>, +} + +impl HttpFilter { + /// Create the wirefilter scheme with HTTP 
request fields + fn create_scheme() -> Scheme { + let mut builder = Scheme! { + http.request.method: Bytes, + http.request.scheme: Bytes, + http.request.host: Bytes, + http.request.port: Int, + http.request.path: Bytes, + http.request.uri: Bytes, + http.request.query: Bytes, + http.request.user_agent: Bytes, + http.request.content_type: Bytes, + http.request.content_length: Int, + http.request.body: Bytes, + http.request.body_sha256: Bytes, + http.request.headers: Map(Array(Bytes)), + ip.src: Ip, + ip.src.country: Bytes, + ip.src.asn: Int, + ip.src.asn_org: Bytes, + ip.src.asn_country: Bytes, + threat.score: Int, + threat.advice: Bytes, + signal.ja4: Bytes, + signal.ja4_raw: Bytes, + signal.ja4_unsorted: Bytes, + signal.ja4_raw_unsorted: Bytes, + signal.tls_version: Bytes, + signal.cipher_suite: Bytes, + signal.sni: Bytes, + signal.alpn: Bytes, + signal.ja4h: Bytes, + signal.ja4h_method: Bytes, + signal.ja4h_version: Bytes, + signal.ja4h_has_cookie: Int, + signal.ja4h_has_referer: Int, + signal.ja4h_header_count: Int, + signal.ja4h_language: Bytes, + signal.ja4t: Bytes, + signal.ja4t_window_size: Int, + signal.ja4t_ttl: Int, + signal.ja4t_mss: Int, + signal.ja4t_window_scale: Int, + signal.ja4l_client: Bytes, + signal.ja4l_server: Bytes, + signal.ja4l_syn_time: Int, + signal.ja4l_synack_time: Int, + signal.ja4l_ack_time: Int, + signal.ja4l_ttl_client: Int, + signal.ja4l_ttl_server: Int, + signal.ja4s: Bytes, + signal.ja4s_proto: Bytes, + signal.ja4s_version: Bytes, + signal.ja4s_cipher: Int, + signal.ja4s_alpn: Bytes, + signal.ja4x: Bytes, + signal.ja4x_issuer_rdns: Bytes, + signal.ja4x_subject_rdns: Bytes, + signal.ja4x_extensions: Bytes, + }; + + // Register functions used in Cloudflare-style expressions + builder.add_function("any", wirefilter::AnyFunction::default()).unwrap(); + builder.add_function("all", wirefilter::AllFunction::default()).unwrap(); + + builder.add_function("cidr", wirefilter::CIDRFunction::default()).unwrap(); + builder.add_function("concat", wirefilter::ConcatFunction::default()).unwrap(); + builder.add_function("decode_base64", wirefilter::DecodeBase64Function::default()).unwrap(); + builder.add_function("ends_with", wirefilter::EndsWithFunction::default()).unwrap(); + builder.add_function("json_lookup_integer", wirefilter::JsonLookupIntegerFunction::default()).unwrap(); + builder.add_function("json_lookup_string", wirefilter::JsonLookupStringFunction::default()).unwrap(); + builder.add_function("len", wirefilter::LenFunction::default()).unwrap(); + builder.add_function("lower", wirefilter::LowerFunction::default()).unwrap(); + builder.add_function("remove_bytes", wirefilter::RemoveBytesFunction::default()).unwrap(); + builder.add_function("remove_query_args", wirefilter::RemoveQueryArgsFunction::default()).unwrap(); + builder.add_function("starts_with", wirefilter::StartsWithFunction::default()).unwrap(); + builder.add_function("substring", wirefilter::SubstringFunction::default()).unwrap(); + builder.add_function("to_string", wirefilter::ToStringFunction::default()).unwrap(); + builder.add_function("upper", wirefilter::UpperFunction::default()).unwrap(); + builder.add_function("url_decode", wirefilter::UrlDecodeFunction::default()).unwrap(); + builder.add_function("uuid4", wirefilter::UUID4Function::default()).unwrap(); + builder.add_function("wildcard_replace", wirefilter::WildcardReplaceFunction::default()).unwrap(); + + + builder.build() + } + + /// Create a new HTTP filter with the given filter expression (static version) + pub fn new(filter_expr: &'static 
str) -> Result<Self> {
+        // Create the scheme with HTTP request fields
+        let scheme = Arc::new(Self::create_scheme());
+
+        // Parse the filter expression
+        let ast = scheme.parse(filter_expr)?;
+
+        // Compile the filter
+        let filter = ast.compile();
+
+        Ok(Self {
+            scheme,
+            rules: Arc::new(RwLock::new(vec![
+                (filter, WafAction::Block, "default".to_string(), "default".to_string(), None)
+            ])),
+            rules_hash: Arc::new(RwLock::new(None)),
+        })
+    }
+
+    /// Create a new HTTP filter from config WAF rules
+    pub fn new_from_config(config: &Config) -> Result<Self> {
+        // Create the scheme with HTTP request fields
+        let scheme = Arc::new(Self::create_scheme());
+
+        if config.waf_rules.rules.is_empty() {
+            // If no WAF rules, create a default filter that allows all
+            return Ok(Self {
+                scheme,
+                rules: Arc::new(RwLock::new(vec![])),
+                rules_hash: Arc::new(RwLock::new(Some(Self::compute_rules_hash("")))),
+            });
+        }
+
+        // Validate and compile individual WAF rules
+        let mut compiled_rules = Vec::new();
+        let mut rules_hash_input = String::new();
+
+        for rule in &config.waf_rules.rules {
+            // Basic validation - check if expression is not empty
+            if rule.expression.trim().is_empty() {
+                log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name);
+                continue;
+            }
+
+            // Try to parse the expression to validate it
+            if let Err(error) = scheme.parse(&rule.expression) {
+                log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error);
+                continue;
+            }
+
+            // Compile the rule using interned expression to minimize memory leakage
+            let expression = intern_expression(&rule.expression);
+            let ast = scheme.parse(expression)?;
+            let filter = ast.compile();
+            let action = WafAction::from_str(&rule.action);
+
+            // Parse rate limit config if action is RateLimit
+            let rate_limit_config = if action == WafAction::RateLimit {
+                rule.config.as_ref().and_then(|cfg| {
+                    match crate::worker::config::RateLimitConfig::from_json(cfg) {
+                        Ok(config) => {
+                            log::debug!("Parsed rate limit config for rule {}: period={}, requests={}",
+                                rule.id, config.period, config.requests);
+                            Some(config)
+                        }
+                        Err(e) => {
+                            log::error!("Failed to parse rate limit config for rule {}: {}. 
Config JSON: {}", + rule.id, e, serde_json::to_string(cfg).unwrap_or_else(|_| "invalid json".to_string())); + None + } + } + }) + } else { + None + }; + + compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); + rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); + } + + if compiled_rules.is_empty() { + log::warn!("No valid WAF rules found, using default filter that allows all"); + return Ok(Self { + scheme, + rules: Arc::new(RwLock::new(vec![])), + rules_hash: Arc::new(RwLock::new(Some(Self::compute_rules_hash("")))), + }); + } + + let hash = Self::compute_rules_hash(&rules_hash_input); + log::debug!("WAF expression interner now has {} unique expressions", get_interned_expression_count()); + Ok(Self { + scheme, + rules: Arc::new(RwLock::new(compiled_rules)), + rules_hash: Arc::new(RwLock::new(Some(hash))), + }) + } + + /// Update the filter with new WAF rules from config + pub fn update_from_config(&self, config: &Config) -> Result<()> { + // Validate and compile individual WAF rules + let mut compiled_rules = Vec::new(); + let mut rules_hash_input = String::new(); + + for rule in &config.waf_rules.rules { + // Basic validation - check if expression is not empty + if rule.expression.trim().is_empty() { + log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name); + continue; + } + + // Try to parse the expression to validate it + if let Err(error) = self.scheme.parse(&rule.expression) { + log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error); + continue; + } + + // Compile the rule using interned expression to minimize memory leakage + let expression = intern_expression(&rule.expression); + let ast = self.scheme.parse(expression)?; + let filter = ast.compile(); + let action = WafAction::from_str(&rule.action); + + // Parse rate limit config if action is RateLimit + let rate_limit_config = if action == WafAction::RateLimit { + rule.config.as_ref().and_then(|cfg| { + match crate::worker::config::RateLimitConfig::from_json(cfg) { + Ok(config) => { + log::debug!("Parsed rate limit config for rule {}: period={}, requests={}", + rule.id, config.period, config.requests); + Some(config) + } + Err(e) => { + log::error!("Failed to parse rate limit config for rule {}: {}. 
Config JSON: {}", + rule.id, e, serde_json::to_string(cfg).unwrap_or_else(|_| "invalid json".to_string())); + None + } + } + }) + } else { + None + }; + + compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); + rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); + } + + // Compute hash and skip update if unchanged + let new_hash = Self::compute_rules_hash(&rules_hash_input); + if let Some(prev) = self.rules_hash.read().unwrap().as_ref() { + if prev == &new_hash { + log::debug!("HTTP filter WAF rules unchanged; skipping update"); + return Ok(()); + } + } + + let rules_count = compiled_rules.len(); + *self.rules.write().unwrap() = compiled_rules; + *self.rules_hash.write().unwrap() = Some(new_hash); + + log::info!("HTTP filter updated with {} WAF rules from config (expression interner: {} unique expressions)", + rules_count, get_interned_expression_count()); + + Ok(()) + } + + fn compute_rules_hash(expr: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(expr.as_bytes()); + hex::encode(hasher.finalize()) + } + + /// Get the current filter expression (for debugging) + pub fn get_current_expression(&self) -> String { + // This is a simplified version - in practice you might want to store the original expression + "dynamic_filter_from_config".to_string() + } + + /// Check if the given HTTP request should be blocked using request parts and body bytes + pub async fn should_block_request_from_parts( + &self, + req_parts: &hyper::http::request::Parts, + body_bytes: &[u8], + peer_addr: SocketAddr, + ) -> Result> { + // Create execution context + let mut ctx = ExecutionContext::new(&self.scheme); + + // Extract request information + let method = req_parts.method.as_str(); + let uri = &req_parts.uri; + let scheme = uri.scheme().map(|s| s.as_str()).unwrap_or("http"); + let host = uri.host().unwrap_or("").to_string(); + let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 }); + let path = uri.path().to_string(); + let full_uri = uri.to_string(); + let query = uri.query().unwrap_or("").to_string(); + + // Extract headers + let user_agent = req_parts + .headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + let content_type = req_parts + .headers + .get("content-type") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + // Get content length + let content_length = req_parts + .headers + .get("content-length") + .and_then(|h| h.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .unwrap_or(body_bytes.len() as i64); + + // Process request body + let body_text = String::from_utf8_lossy(body_bytes).to_string(); + + // Calculate body SHA256 + let mut hasher = Sha256::new(); + hasher.update(body_bytes); + let body_sha256_hex = hex::encode(hasher.finalize()); + + // Set field values in execution context + ctx.set_field_value( + self.scheme.get_field("http.request.method").unwrap(), + method, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.scheme").unwrap(), + scheme, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.host").unwrap(), + host, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.port").unwrap(), + port as i64, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.path").unwrap(), + path, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.uri").unwrap(), + full_uri, + )?; + ctx.set_field_value( + 
self.scheme.get_field("http.request.query").unwrap(), + query, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.user_agent").unwrap(), + user_agent, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.content_type").unwrap(), + content_type, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.headers").unwrap(), + { + let mut headers_map: TypedMap<'_, TypedArray<'_, &[u8]>> = TypedMap::new(); + for (name, value) in req_parts.headers.iter() { + let key = name.as_str().to_ascii_lowercase().into_bytes().into_boxed_slice(); + let entry = headers_map.get_or_insert(key, TypedArray::new()); + match value.to_str() { + Ok(s) => entry.push(s.as_bytes()), + Err(_) => entry.push(value.as_bytes()), + } + } + headers_map + }, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.content_length").unwrap(), + content_length, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.body").unwrap(), + body_text, + )?; + ctx.set_field_value( + self.scheme.get_field("http.request.body_sha256").unwrap(), + body_sha256_hex, + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src").unwrap(), + peer_addr.ip(), + )?; + + // Fetch threat intelligence data for the source IP + // Fetch full threat response for access logging, and WAF fields for rule evaluation + log::info!("🔍 [WAF] Looking up GeoIP for: {}", peer_addr.ip()); + let threat_response = threat::get_threat_intel(&peer_addr.ip().to_string()).await.ok().flatten(); + let _threat_fields = if let Some(ref threat_resp) = threat_response { + let waf_fields = threat::WafFields::from(threat_resp); + log::info!("🔍 [WAF] GeoIP Result: Country='{}', ASN={}, ThreatScore={}", + waf_fields.ip_src_country, waf_fields.ip_src_asn, waf_fields.threat_score); + // Set threat intelligence fields + ctx.set_field_value( + self.scheme.get_field("ip.src.country").unwrap(), + waf_fields.ip_src_country.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn").unwrap(), + waf_fields.ip_src_asn as i64, + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn_org").unwrap(), + waf_fields.ip_src_asn_org.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn_country").unwrap(), + waf_fields.ip_src_asn_country.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("threat.score").unwrap(), + waf_fields.threat_score as i64, + )?; + ctx.set_field_value( + self.scheme.get_field("threat.advice").unwrap(), + waf_fields.threat_advice.clone(), + )?; + Some(waf_fields) + } else { + // No threat data found, set default values + log::warn!("🔍 [WAF] No GeoIP data found for {}, setting empty country", peer_addr.ip()); + ctx.set_field_value( + self.scheme.get_field("ip.src.country").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn_org").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("ip.src.asn_country").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("threat.score").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("threat.advice").unwrap(), + "", + )?; + None + }; + + // Extract HTTP version + let http_version = format!("{:?}", req_parts.version); + + // Generate JA4H fingerprint from HTTP request (available now) + let ja4h_fp = crate::ja4_plus::Ja4hFingerprint::from_http_request( + method, + &http_version, + &req_parts.headers, + ); + + // Set default empty values for all signal (JA4) fields + // 
These fields will be populated when JA4 data is available + ctx.set_field_value( + self.scheme.get_field("signal.ja4").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4_raw").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4_unsorted").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4_raw_unsorted").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.tls_version").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.cipher_suite").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.sni").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.alpn").unwrap(), + "", + )?; + // Populate JA4H fields from generated fingerprint + ctx.set_field_value( + self.scheme.get_field("signal.ja4h").unwrap(), + ja4h_fp.fingerprint.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_method").unwrap(), + ja4h_fp.method.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_version").unwrap(), + ja4h_fp.version.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_has_cookie").unwrap(), + if ja4h_fp.has_cookie { 1i64 } else { 0i64 }, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_has_referer").unwrap(), + if ja4h_fp.has_referer { 1i64 } else { 0i64 }, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_header_count").unwrap(), + ja4h_fp.header_count as i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4h_language").unwrap(), + ja4h_fp.language.clone(), + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4t").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4t_window_size").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4t_ttl").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4t_mss").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4t_window_scale").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_client").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_server").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_syn_time").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_synack_time").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_ack_time").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_ttl_client").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4l_ttl_server").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4s").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4s_proto").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4s_version").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4s_cipher").unwrap(), + 0i64, + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4s_alpn").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4x").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4x_issuer_rdns").unwrap(), + "", + )?; + ctx.set_field_value( + self.scheme.get_field("signal.ja4x_subject_rdns").unwrap(), + "", + )?; + ctx.set_field_value( + 
self.scheme.get_field("signal.ja4x_extensions").unwrap(), + "", + )?; + + // Execute each rule individually and return the first match + let rules_guard = self.rules.read().unwrap(); + for (filter, action, rule_name, rule_id, rate_limit_config) in rules_guard.iter() { + let rule_result = filter.execute(&ctx)?; + if rule_result { + return Ok(Some(WafResult { + action: action.clone(), + rule_name: rule_name.clone(), + rule_id: rule_id.clone(), + rate_limit_config: rate_limit_config.clone(), + threat_response: threat_response.clone(), + })); + } + } + + Ok(None) + } +} + +// Global wirefilter instance for HTTP request filtering +static HTTP_FILTER: OnceLock = OnceLock::new(); + +pub fn get_global_http_filter() -> Option<&'static HttpFilter> { + HTTP_FILTER.get() +} + +pub fn set_global_http_filter(filter: HttpFilter) -> anyhow::Result<()> { + HTTP_FILTER + .set(filter) + .map_err(|_| anyhow!("Failed to initialize HTTP filter")) +} + + +/// Initialize the global config + HTTP filter from API with retry logic +pub async fn init_config(base_url: String, api_key: String) -> anyhow::Result<()> { + let mut retry_count = 0; + const MAX_RETRIES: u32 = 3; + const RETRY_DELAY_MS: u64 = 1000; + + loop { + match fetch_config(base_url.clone(), api_key.clone()).await { + Ok(config_response) => { + let filter = HttpFilter::new_from_config(&config_response.config)?; + set_global_http_filter(filter)?; + log::info!("HTTP filter initialized with {} WAF rules from config", config_response.config.waf_rules.rules.len()); + return Ok(()); + } + Err(e) => { + let error_msg = e.to_string(); + if error_msg.contains("503") && retry_count < MAX_RETRIES { + retry_count += 1; + log::warn!("Failed to fetch config for HTTP filter (attempt {}): {}. Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); + tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; + continue; + } else { + log::error!("Failed to fetch config for HTTP filter after {} attempts: {}", retry_count + 1, error_msg); + return Err(anyhow!("Failed to initialize HTTP filter: {}", error_msg)); + } + } + } + } +} + +/// Update the global HTTP filter with new config with retry logic +pub async fn update_with_config(base_url: String, api_key: String) -> anyhow::Result<()> { + let mut retry_count = 0; + const MAX_RETRIES: u32 = 3; + const RETRY_DELAY_MS: u64 = 1000; + + loop { + match fetch_config(base_url.clone(), api_key.clone()).await { + Ok(config_response) => { + if let Some(filter) = HTTP_FILTER.get() { + filter.update_from_config(&config_response.config)?; + } else { + log::warn!("HTTP filter not initialized, cannot update"); + } + return Ok(()); + } + Err(e) => { + let error_msg = e.to_string(); + if error_msg.contains("503") && retry_count < MAX_RETRIES { + retry_count += 1; + log::warn!("Failed to fetch config for HTTP filter update (attempt {}): {}. 
Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); + tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; + continue; + } else { + log::error!("Failed to fetch config for HTTP filter update after {} attempts: {}", retry_count + 1, error_msg); + return Err(anyhow!("Failed to fetch config: {}", error_msg)); + } + } + } + } +} + +/// Update the global HTTP filter using an already-fetched Config value +pub fn update_http_filter_from_config_value(config: &Config) -> anyhow::Result<()> { + if let Some(filter) = HTTP_FILTER.get() { + filter.update_from_config(config)?; + Ok(()) + } else { + log::warn!("HTTP filter not initialized, cannot update"); + Ok(()) + } +} + +/// Load WAF rules from a Vec of WafRule (for local mode) +pub async fn load_waf_rules(waf_rules: Vec) -> anyhow::Result<()> { + // Create a minimal config with just WAF rules + let config = crate::worker::config::Config { + access_rules: crate::worker::config::AccessRule { + id: "local".to_string(), + name: "Local Rules".to_string(), + description: "Local security rules".to_string(), + allow: crate::worker::config::RuleSet { + asn: vec![], + country: vec![], + ips: vec![], + }, + block: crate::worker::config::RuleSet { + asn: vec![], + country: vec![], + ips: vec![], + }, + }, + waf_rules: crate::worker::config::WafRules { rules: waf_rules }, + content_scanning: crate::content_scanning::ContentScanningConfig::default(), + created_at: chrono::Utc::now().to_rfc3339(), + updated_at: chrono::Utc::now().to_rfc3339(), + last_modified: chrono::Utc::now().to_rfc3339(), + }; + + let filter = HttpFilter::new_from_config(&config)?; + set_global_http_filter(filter)?; + Ok(()) +} + +/// Evaluate WAF rules for a Pingora request +/// This is a convenience function that converts Pingora's RequestHeader to hyper's Parts +#[cfg(feature = "proxy")] +pub async fn evaluate_waf_for_pingora_request( + req_header: &pingora_http::RequestHeader, + body_bytes: &[u8], + peer_addr: SocketAddr, +) -> Result> { + let filter = match get_global_http_filter() { + Some(f) => { + // Check if filter has any rules + let rules_count = f.rules.read().unwrap().len(); + if rules_count == 0 { + log::debug!("WAF filter initialized but has no rules loaded"); + } else { + log::debug!("WAF filter has {} rules loaded", rules_count); + } + f + } + None => { + log::debug!("WAF filter not initialized, skipping evaluation"); + return Ok(None); + } + }; + + // Convert Pingora RequestHeader to hyper::http::request::Parts + // Pingora URIs might be relative, so we need to construct a full URI + let uri_str = if req_header.uri.scheme().is_some() { + // Already an absolute URI + req_header.uri.to_string() + } else { + // Construct absolute URI from relative path + // Use http://localhost as base since we only need the path/query for WAF evaluation + format!("http://localhost{}", req_header.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")) + }; + + let uri = match uri_str.parse::() { + Ok(u) => u, + Err(e) => { + log::error!("WAF: Failed to parse URI '{}': {}", uri_str, e); + return Err(anyhow!("Failed to parse URI: {}", e)); + } + }; + + let mut builder = hyper::http::request::Builder::new() + .method(req_header.method.as_str()) + .uri(uri); + + // Copy headers + for (name, value) in req_header.headers.iter() { + if let Ok(name_str) = name.as_str().parse::() { + if let Ok(value_str) = value.to_str() { + builder = builder.header(name_str, value_str); + } else { + builder = builder.header(name_str, value.as_bytes()); + } + } else { + 
log::debug!("WAF: Failed to parse header name: {}", name.as_str()); + } + } + + let req = match builder.body(()) { + Ok(r) => r, + Err(e) => { + log::error!("WAF: Failed to build hyper request: {}", e); + return Err(anyhow!("Failed to build hyper request: {}", e)); + } + }; + let (req_parts, _) = req.into_parts(); + + log::debug!("WAF: Evaluating request - method={}, uri={}, peer={}", + req_header.method.as_str(), uri_str, peer_addr); + + match filter.should_block_request_from_parts(&req_parts, body_bytes, peer_addr).await { + Ok(result) => { + if result.is_some() { + log::debug!("WAF: Rule matched - {:?}", result); + } + Ok(result) + } + Err(e) => { + log::error!("WAF: Evaluation error: {}", e); + Err(e) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::http::request::Builder; + use std::net::Ipv4Addr; + + + #[tokio::test] + async fn test_custom_filter() -> Result<()> { + // Test a custom filter that blocks requests to specific host + let filter = HttpFilter::new("http.request.host == \"blocked.example.com\"")?; + + let req = Builder::new() + .method("GET") + .uri("http://blocked.example.com/test") + .body(())?; + let (req_parts, _) = req.into_parts(); + + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + let result = filter.should_block_request_from_parts(&req_parts, b"", peer_addr).await?; + if let Some(waf_result) = result { + assert_eq!(waf_result.action, WafAction::Block, "Request to blocked host should be blocked"); + } else { + panic!("Request to blocked host should be blocked"); + } + + Ok(()) + } + + #[tokio::test] + async fn test_content_scanning_integration() -> Result<()> { + // Test content scanning integration with wirefilter + let filter = HttpFilter::new("http.request.host == \"example.com\"")?; + + let req = Builder::new() + .method("POST") + .uri("http://example.com/test") + .header("content-type", "text/html") + .body(())?; + let (req_parts, _) = req.into_parts(); + + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + // Test with clean content (should not be blocked by content scanning) + let clean_content = b"Clean content"; + let result = filter.should_block_request_from_parts(&req_parts, clean_content, peer_addr).await?; + + // Should be blocked by host rule, not content scanning + if let Some(waf_result) = result { + assert_eq!(waf_result.rule_name, "default", "Request to example.com should be blocked by host rule"); + } else { + panic!("Request to example.com should be blocked by host rule"); + } + + Ok(()) + } + + #[tokio::test] + async fn test_ja4h_http_version_extraction() -> Result<()> { + // Test that HTTP version is correctly extracted and used in JA4H fingerprint + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + // Test HTTP/1.0 + let filter_http10 = HttpFilter::new("signal.ja4h_version == \"HTTP/1.0\"")?; + let req_http10 = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_10) + .body(())?; + let (req_parts_http10, _) = req_http10.into_parts(); + let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; + if let Some(waf_result) = result_http10 { + assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.0 request should match version check"); + } else { + panic!("HTTP/1.0 request should match version check"); + } + + // Test HTTP/1.1 + let filter_http11 = HttpFilter::new("signal.ja4h_version == \"HTTP/1.1\"")?; + let req_http11 = Builder::new() + 
.method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_11) + .body(())?; + let (req_parts_http11, _) = req_http11.into_parts(); + let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; + if let Some(waf_result) = result_http11 { + assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.1 request should match version check"); + } else { + panic!("HTTP/1.1 request should match version check"); + } + + // Test HTTP/2.0 + let filter_http2 = HttpFilter::new("signal.ja4h_version == \"HTTP/2.0\"")?; + let req_http2 = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_2) + .body(())?; + let (req_parts_http2, _) = req_http2.into_parts(); + let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; + if let Some(waf_result) = result_http2 { + assert_eq!(waf_result.action, WafAction::Block, "HTTP/2.0 request should match version check"); + } else { + panic!("HTTP/2.0 request should match version check"); + } + + // Test that wrong version doesn't match + let filter_wrong_version = HttpFilter::new("signal.ja4h_version == \"HTTP/1.0\"")?; + let req_wrong = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_11) + .body(())?; + let (req_parts_wrong, _) = req_wrong.into_parts(); + let result_wrong = filter_wrong_version.should_block_request_from_parts(&req_parts_wrong, b"", peer_addr).await?; + assert!(result_wrong.is_none(), "HTTP/1.1 request should not match HTTP/1.0 version check"); + + Ok(()) + } + + #[tokio::test] + async fn test_ja4h_fingerprint_with_different_versions() -> Result<()> { + // Test that JA4H fingerprint is correctly generated with different HTTP versions + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + // Test that JA4H fingerprint starts with correct version code for HTTP/1.0 (should be "10") + let filter_http10 = HttpFilter::new("starts_with(signal.ja4h, \"ge10\")")?; + let req_http10 = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_10) + .body(())?; + let (req_parts_http10, _) = req_http10.into_parts(); + let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; + assert!(result_http10.is_some(), "HTTP/1.0 request should generate JA4H starting with 'ge10'"); + + // Test that JA4H fingerprint starts with correct version code for HTTP/1.1 (should be "11") + let filter_http11 = HttpFilter::new("starts_with(signal.ja4h, \"ge11\")")?; + let req_http11 = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_11) + .body(())?; + let (req_parts_http11, _) = req_http11.into_parts(); + let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; + assert!(result_http11.is_some(), "HTTP/1.1 request should generate JA4H starting with 'ge11'"); + + // Test that JA4H fingerprint starts with correct version code for HTTP/2.0 (should be "20") + let filter_http2 = HttpFilter::new("starts_with(signal.ja4h, \"ge20\")")?; + let req_http2 = Builder::new() + .method("GET") + .uri("http://example.com/test") + .version(hyper::http::Version::HTTP_2) + .body(())?; + let (req_parts_http2, _) = req_http2.into_parts(); + let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; + 
assert!(result_http2.is_some(), "HTTP/2.0 request should generate JA4H starting with 'ge20'"); + + Ok(()) + } +} diff --git a/src/worker/agent_status.rs b/src/worker/agent_status.rs index 19b209b..92a07a0 100644 --- a/src/worker/agent_status.rs +++ b/src/worker/agent_status.rs @@ -1,57 +1,57 @@ -use std::time::Duration; - -use tokio::sync::watch; -use tokio::time::interval; - -use crate::agent_status::AgentStatusIdentity; -use crate::worker::log::{send_event, UnifiedEvent}; - -/// Agent status worker that sends register + heartbeat events -pub struct AgentStatusWorker { - identity: AgentStatusIdentity, - interval_secs: u64, -} - -impl AgentStatusWorker { - pub fn new(identity: AgentStatusIdentity, interval_secs: u64) -> Self { - Self { - identity, - interval_secs, - } - } -} - -impl super::Worker for AgentStatusWorker { - fn name(&self) -> &str { - "agent_status" - } - - fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { - let identity = self.identity.clone(); - let interval_secs = self.interval_secs; - let worker_name = self.name().to_string(); - - tokio::spawn(async move { - let mut tick = interval(Duration::from_secs(interval_secs)); - tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - // Initial register event - send_event(UnifiedEvent::AgentStatus(identity.to_event("running"))); - - loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] Shutdown signal received, stopping agent status worker", worker_name); - break; - } - } - _ = tick.tick() => { - send_event(UnifiedEvent::AgentStatus(identity.to_event("running"))); - } - } - } - }) - } -} - +use std::time::Duration; + +use tokio::sync::watch; +use tokio::time::interval; + +use crate::agent_status::AgentStatusIdentity; +use crate::worker::log::{send_event, UnifiedEvent}; + +/// Agent status worker that sends register + heartbeat events +pub struct AgentStatusWorker { + identity: AgentStatusIdentity, + interval_secs: u64, +} + +impl AgentStatusWorker { + pub fn new(identity: AgentStatusIdentity, interval_secs: u64) -> Self { + Self { + identity, + interval_secs, + } + } +} + +impl super::Worker for AgentStatusWorker { + fn name(&self) -> &str { + "agent_status" + } + + fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { + let identity = self.identity.clone(); + let interval_secs = self.interval_secs; + let worker_name = self.name().to_string(); + + tokio::spawn(async move { + let mut tick = interval(Duration::from_secs(interval_secs)); + tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + // Initial register event + send_event(UnifiedEvent::AgentStatus(identity.to_event("running"))); + + loop { + tokio::select! 
diff --git a/src/worker/certificate.rs b/src/worker/certificate.rs
index 07d3b2f..0afbde9 100644
--- a/src/worker/certificate.rs
+++ b/src/worker/certificate.rs
@@ -1,1158 +1,1158 @@
-use anyhow::{Context, Result};
-use std::io::Write;
-use std::sync::Arc;
-use tokio::sync::watch;
-use tokio::time::{interval, Duration};
-
-use crate::redis::RedisManager;
-use crate::utils::tls::{CertificateConfig, Certificates};
-use crate::worker::Worker;
-
-/// Calculate SHA256 hash of certificate files (fullchain + key)
-fn calculate_local_hash(cert_path: &std::path::Path, key_path: &std::path::Path) -> Result<String> {
-    use sha2::{Sha256, Digest};
-    use std::io::Read;
-
-    let mut hasher = Sha256::new();
-
-    // Read and hash certificate file
-    let mut cert_file = std::fs::File::open(cert_path)
-        .context(format!("Failed to open certificate file: {}", cert_path.display()))?;
-    let mut cert_data = Vec::new();
-    cert_file.read_to_end(&mut cert_data)
-        .context(format!("Failed to read certificate file: {}", cert_path.display()))?;
-    hasher.update(&cert_data);
-
-    // Read and hash key file
-    let mut key_file = std::fs::File::open(key_path)
-        .context(format!("Failed to open key file: {}", key_path.display()))?;
-    let mut key_data = Vec::new();
-    key_file.read_to_end(&mut key_data)
-        .context(format!("Failed to read key file: {}", key_path.display()))?;
-    hasher.update(&key_data);
-
-    Ok(format!("{:x}", hasher.finalize()))
-}
-
-/// Normalize PEM certificate chain to ensure proper format
-/// - Ensures newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE)
-/// - Ensures file ends with newline
-fn normalize_pem_chain(chain: &str) -> String {
-    let mut normalized = chain.to_string();
-
-    // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE
-    // Replace "-----END CERTIFICATE----------BEGIN CERTIFICATE-----" with proper newline
-    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----",
-        "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----");
-
-    // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files)
-    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----",
-        "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----");
-
-    // Ensure file ends with newline
-    if !normalized.ends_with('\n') {
-        normalized.push('\n');
-    }
-
-    normalized
}
-
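The normalization above only has to repair one failure mode: PEM blocks that were concatenated without a separating newline. A self-contained illustration of the intended behavior (a standalone sketch, not a test from this patch):

fn main() {
    // Two PEM markers glued together, as produced by naive concatenation.
    let glued = "-----END CERTIFICATE----------BEGIN CERTIFICATE-----";
    let fixed = glued.replace(
        "-----END CERTIFICATE----------BEGIN CERTIFICATE-----",
        "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----",
    );
    assert_eq!(fixed, "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----");
}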
-/// Global certificate store for Redis-loaded certificates
-static CERTIFICATE_STORE: once_cell::sync::OnceCell<Arc<tokio::sync::RwLock<Option<Arc<Certificates>>>>> = once_cell::sync::OnceCell::new();
-
-/// Global in-memory cache for certificate hashes (domain -> SHA256 hash)
-/// Using Arc<RwLock<HashMap>> instead of MemoryCache to avoid lifetime issues
-static CERTIFICATE_HASH_CACHE: once_cell::sync::OnceCell<Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>> = once_cell::sync::OnceCell::new();
-
-/// Get the global certificate store
-pub fn get_certificate_store() -> Arc<tokio::sync::RwLock<Option<Arc<Certificates>>>> {
-    CERTIFICATE_STORE.get_or_init(|| Arc::new(tokio::sync::RwLock::new(None))).clone()
-}
-
-/// Get the global certificate hash cache
-/// Cache size: 1000 entries (should be enough for most deployments)
-fn get_certificate_hash_cache() -> Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>> {
-    CERTIFICATE_HASH_CACHE.get_or_init(|| {
-        Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()))
-    }).clone()
-}
-
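Consumers of the store clone the inner Arc out of the lock, so the read guard is held only for the duration of the clone rather than for the lifetime of a TLS handshake. A minimal sketch of that read path (the helper name is illustrative):

// Illustrative consumer of get_certificate_store(): clone the Arc and
// release the read lock before the certificates are actually used.
async fn current_certificates() -> Option<std::sync::Arc<crate::utils::tls::Certificates>> {
    let store = crate::worker::certificate::get_certificate_store();
    let guard = store.read().await;
    guard.as_ref().cloned()
}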
-/// Certificate worker that fetches SSL certificates from Redis
-/// Uses upstreams.yaml as the source of truth for domains
-pub struct CertificateWorker {
-    certificate_path: String,
-    upstreams_path: String,
-    refresh_interval_secs: u64,
-}
-
-impl CertificateWorker {
-    pub fn new(certificate_path: String, upstreams_path: String, refresh_interval_secs: u64) -> Self {
-        Self {
-            certificate_path,
-            upstreams_path,
-            refresh_interval_secs,
-        }
-    }
-}
-
-impl Worker for CertificateWorker {
-    fn name(&self) -> &str {
-        "certificate"
-    }
-
-    fn run(&self, mut shutdown: watch::Receiver<bool>) -> tokio::task::JoinHandle<()> {
-        let certificate_path = self.certificate_path.clone();
-        let upstreams_path = self.upstreams_path.clone();
-        let refresh_interval_secs = self.refresh_interval_secs;
-        let worker_name = self.name().to_string();
-
-        tokio::spawn(async move {
-            // Store upstreams_path globally for ACME requests
-            set_upstreams_path(upstreams_path.clone());
-
-            // Initial fetch on startup - download all certificates immediately
-            log::info!("[{}] Starting certificate download from Redis on service startup...", worker_name);
-            match fetch_certificates_from_redis(&certificate_path, &upstreams_path).await {
-                Ok(_) => {
-                    log::info!("[{}] Successfully downloaded all certificates from Redis on startup", worker_name);
-                }
-                Err(e) => {
-                    log::warn!("[{}] Failed to fetch certificates from Redis on startup: {}", worker_name, e);
-                    log::warn!("[{}] Will retry on next scheduled interval", worker_name);
-                }
-            }
-
-            // Set up periodic refresh interval
-            let mut refresh_interval = interval(Duration::from_secs(refresh_interval_secs));
-            // Skip the first tick since we already fetched on startup
-            refresh_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
-
-            // Set up periodic expiration check (every 6 hours)
-            let mut expiration_check_interval = interval(Duration::from_secs(6 * 60 * 60)); // 6 hours
-            expiration_check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
-            // Skip the first tick - we'll check after the first certificate fetch
-            let mut first_expiration_check = true;
-
-            loop {
-                tokio::select! {
-                    _ = shutdown.changed() => {
-                        if *shutdown.borrow() {
-                            break;
-                        }
-                    }
-                    _ = refresh_interval.tick() => {
-                        log::debug!("[{}] Periodic certificate refresh triggered", worker_name);
-                        // Update upstreams_path in case it changed
-                        set_upstreams_path(upstreams_path.clone());
-                        if let Err(e) = fetch_certificates_from_redis(&certificate_path, &upstreams_path).await {
-                            log::warn!("[{}] Failed to fetch certificates from Redis: {}", worker_name, e);
-                        }
-                    }
-                    _ = expiration_check_interval.tick() => {
-                        if first_expiration_check {
-                            first_expiration_check = false;
-                            continue; // Skip first check, wait for next interval
-                        }
-                        log::info!("[{}] Periodic certificate expiration check triggered", worker_name);
-                        // Update upstreams_path in case it changed
-                        set_upstreams_path(upstreams_path.clone());
-                        if let Err(e) = check_and_renew_expiring_certificates(&upstreams_path).await {
-                            log::warn!("[{}] Failed to check certificate expiration: {}", worker_name, e);
-                        }
-                    }
-                }
-            }
-
-            log::info!("[{}] Certificate fetcher task stopped", worker_name);
-        })
-    }
-}
-
-/// Start a background task that periodically fetches SSL certificates from Redis
-/// This is kept for backward compatibility
-pub fn start_certificate_fetcher(
-    certificate_path: String,
-    upstreams_path: String,
-    refresh_interval_secs: u64,
-    shutdown: watch::Receiver<bool>,
-) -> tokio::task::JoinHandle<()> {
-    let worker = CertificateWorker::new(certificate_path, upstreams_path, refresh_interval_secs);
-    worker.run(shutdown)
-}
-
-/// Fetch domains from upstreams.yaml file (source of truth)
-async fn fetch_domains_from_upstreams(upstreams_path: &str) -> Result<Vec<String>> {
-    use serde_yaml;
-    use std::path::PathBuf;
-
-    let path = PathBuf::from(upstreams_path);
-
-    // Read and parse upstreams.yaml
-    let yaml_content = tokio::fs::read_to_string(&path)
-        .await
-        .with_context(|| format!("Failed to read upstreams file: {:?}", path))?;
-
-    let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content)
-        .with_context(|| format!("Failed to parse upstreams YAML: {:?}", path))?;
-
-    let mut domains = Vec::new();
-
-    if let Some(upstreams) = &parsed.upstreams {
-        for (hostname, _host_config) in upstreams {
-            domains.push(hostname.clone());
-        }
-    }
-
-    log::info!("Found {} domain(s) in upstreams.yaml: {:?}", domains.len(), domains);
-    Ok(domains)
-}
-
-
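fetch_domains_from_upstreams only cares about the keys of the top-level `upstreams` map. A self-contained sketch of the shape it expects, using a hypothetical two-field mirror of `crate::utils::structs::Config` (the real struct carries more fields such as certificate, acme, and ssl_enabled):

use std::collections::HashMap;
use serde::Deserialize;

// Hypothetical, reduced mirror of the real Config struct.
#[derive(Deserialize)]
struct MiniConfig {
    upstreams: Option<HashMap<String, serde_yaml::Value>>,
}

fn main() -> Result<(), serde_yaml::Error> {
    let yaml = "upstreams:\n  example.com:\n    ssl_enabled: true\n";
    let cfg: MiniConfig = serde_yaml::from_str(yaml)?;
    let domains: Vec<String> = cfg.upstreams.unwrap_or_default().keys().cloned().collect();
    assert_eq!(domains, vec!["example.com".to_string()]);
    Ok(())
}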
-/// Fetch SSL certificates from Redis for domains listed in upstreams.yaml
-async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: &str) -> Result<()> {
-    let redis_manager = RedisManager::get()
-        .context("Redis manager not initialized")?;
-
-    // Parse upstreams.yaml to get domains and their certificate mappings
-    use serde_yaml;
-    use std::path::PathBuf;
-    let path = PathBuf::from(upstreams_path);
-    let yaml_content = tokio::fs::read_to_string(&path)
-        .await
-        .with_context(|| format!("Failed to read upstreams file: {:?}", path))?;
-    let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content)
-        .with_context(|| format!("Failed to parse upstreams YAML: {:?}", path))?;
-
-    // Build mapping of domain -> certificate_name (or None if not specified)
-    // Only include domains that need certificates (have ACME config or ssl_enabled: true)
-    let mut domain_cert_map: Vec<(String, Option<String>)> = Vec::new();
-    if let Some(upstreams) = &parsed.upstreams {
-        for (hostname, host_config) in upstreams {
-            // Only process domains that need certificates
-            if !host_config.needs_certificate() {
-                log::debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", hostname);
-                continue;
-            }
-            let cert_name = host_config.certificate.clone();
-            domain_cert_map.push((hostname.clone(), cert_name));
-        }
-    }
-
-    if domain_cert_map.is_empty() {
-        log::warn!("No domains found in upstreams.yaml, skipping certificate fetch");
-        return Ok(());
-    }
-
-    log::info!("Checking certificates for {} domain(s) from Redis (will skip download if hashes match)", domain_cert_map.len());
-
-    let mut connection = redis_manager.get_connection();
-    let mut certificate_configs = Vec::new();
-    let mut skipped_count = 0;
-    let mut downloaded_count = 0;
-    let mut missing_count = 0;
-    let cert_dir = std::path::Path::new(certificate_path);
-
-    // Create certificate directory if it doesn't exist
-    if !cert_dir.exists() {
-        log::info!("Creating certificate directory: {}", certificate_path);
-        std::fs::create_dir_all(cert_dir)
-            .context(format!("Failed to create certificate directory: {}", certificate_path))?;
-        log::info!("Certificate directory created: {}", certificate_path);
-    } else {
-        log::debug!("Certificate directory already exists: {}", certificate_path);
-    }
-
-    for (domain, cert_name_opt) in &domain_cert_map {
-        // Use certificate name if specified, otherwise use domain name
-        let cert_name = cert_name_opt.as_ref().unwrap_or(domain);
-        // Normalize certificate name (remove wildcard prefix if present)
-        let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(cert_name);
-
-        // Check certificate hash from Redis before downloading
-        // Get prefix from RedisManager
-        let prefix = RedisManager::get()
-            .map(|rm| rm.get_prefix().to_string())
-            .unwrap_or_else(|_| "ssl-storage".to_string());
-        let hash_key = format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name);
-        let remote_hash: Option<String> = redis::cmd("GET")
-            .arg(&hash_key)
-            .query_async(&mut connection)
-            .await
-            .context(format!("Failed to get certificate hash for domain: {}", domain))?;
-
-        // Check in-memory cache first for local hash
-        let hash_cache = get_certificate_hash_cache();
-
-        // Get file paths first - use normalized certificate name for file naming
-        // This matches the filename format used by ACME (which uses normalized domain)
-        // ACME saves as: {sanitized_domain}.crt where domain is normalized (wildcard prefix removed)
-        let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_");
-        let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name));
-        let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name));
-
-        // Check local certificates first before checking Redis
-        // This allows using local certificates even if Redis is unavailable or certificate not in Redis yet
-        let local_hash = if cert_path.exists() && key_path.exists() {
-            // Local files exist - check cache first, then calculate if needed
-            let cached_hash = {
-                let cache = hash_cache.read().await;
-                cache.get(cert_name).cloned()
-            };
-
-            if let Some(cached_hash) = cached_hash {
-                Some(cached_hash)
-            } else if let Ok(calculated_hash) = calculate_local_hash(&cert_path, &key_path) {
-                // Calculate from files and cache it
-                let mut cache = hash_cache.write().await;
-                cache.insert(cert_name.to_string(), calculated_hash.clone());
-                Some(calculated_hash)
-            } else {
-                None
-            }
-        } else {
-            // Files don't exist - clear cache
-            let mut cache = hash_cache.write().await;
-            cache.remove(cert_name);
-            None
-        };
-
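The branch that follows collapses to a small truth table over the pair (local hash, Redis hash). An equivalent, condensed restatement of the same decision (sketch only; the real code also pushes configs and counters as a side effect):

// Condensed restatement of should_download below.
fn decide(local: Option<&str>, remote: Option<&str>) -> bool {
    match (local, remote) {
        (Some(l), Some(r)) => l != r, // re-download only on hash mismatch
        (Some(_), None) => false,     // local file wins when Redis has no hash
        (None, _) => true,            // nothing on disk: try Redis either way
    }
}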
-        // Determine if we need to download
-        // Priority: Use local certificate if available and valid, otherwise check Redis
-        let should_download = if let Some(local) = &local_hash {
-            // Local certificate exists - check if it matches Redis (if Redis hash is available)
-            if let Some(remote) = &remote_hash {
-                if remote == local {
-                    // Hashes match - use local certificate
-                    log::debug!("Certificate hash matches for domain: {} (hash: {}), using local certificate", domain, remote);
-                    // Add existing certificate to config without re-downloading
-                    certificate_configs.push(CertificateConfig {
-                        cert_path: cert_path.to_string_lossy().to_string(),
-                        key_path: key_path.to_string_lossy().to_string(),
-                    });
-                    skipped_count += 1;
-                    log::debug!("Added existing certificate config for domain: {} -> cert: {}, key: {}",
-                        domain, cert_path.display(), key_path.display());
-                    false // Don't download
-                } else {
-                    log::debug!("Certificate hash mismatch for domain: {} (remote: {}, local: {}), downloading new certificate from Redis", domain, remote, local);
-                    true // Download - hash changed in Redis
-                }
-            } else {
-                // No Redis hash, but we have local certificate - use it
-                log::debug!("No Redis hash found for domain: {}, but local certificate exists, using local certificate", domain);
-                // Add existing certificate to config without re-downloading
-                certificate_configs.push(CertificateConfig {
-                    cert_path: cert_path.to_string_lossy().to_string(),
-                    key_path: key_path.to_string_lossy().to_string(),
-                });
-                skipped_count += 1;
-                log::debug!("Added existing local certificate config for domain: {} -> cert: {}, key: {}",
-                    domain, cert_path.display(), key_path.display());
-                false // Don't download - use local certificate
-            }
-        } else if remote_hash.is_some() {
-            // No local certificate, but Redis has one - download it
-            log::debug!("No local certificate found for domain: {}, but Redis hash exists, downloading from Redis", domain);
-            true // Download from Redis
-        } else {
-            // No local certificate and no Redis hash - check if certificate exists in Redis
-            log::debug!("No local certificate and no Redis hash for domain: {}, will check if certificate exists in Redis", domain);
-            true // Check Redis for certificate
-        };
-
-        // Skip download if not needed
-        if !should_download {
-            continue;
-        }
-
-        // Fetch fullchain and private key from Redis
-        // Get prefix from RedisManager
-        let prefix = RedisManager::get()
-            .map(|rm| rm.get_prefix().to_string())
-            .unwrap_or_else(|_| "ssl-storage".to_string());
-        // Redis stores certificates with keys:
-        // - {prefix}:{cert_name}:live:fullchain
-        // - {prefix}:{cert_name}:live:privkey
-        let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_cert_name);
-        let privkey_key = format!("{}:{}:live:privkey", prefix, normalized_cert_name);
-
-        log::info!("Fetching certificate for domain: {} (using cert: {}, normalized: {}, prefix: {})", domain, cert_name, normalized_cert_name, prefix);
-        log::info!("Fullchain key: '{}', Privkey key: '{}'", fullchain_key, privkey_key);
-
-        let fullchain: Option<Vec<u8>> = redis::cmd("GET")
-            .arg(&fullchain_key)
-            .query_async(&mut connection)
-            .await
-            .context(format!("Failed to get fullchain for domain: {}", domain))?;
-
-        let privkey: Option<Vec<u8>> = redis::cmd("GET")
-            .arg(&privkey_key)
-            .query_async(&mut connection)
-            .await
-            .context(format!("Failed to get private key for domain: {}", domain))?;
-
-        log::info!("Redis GET results for domain {}: fullchain={}, privkey={}",
-            domain,
-            if fullchain.is_some() { "Some" } else { "None" },
-            if privkey.is_some() { "Some" } else { "None" }
-        );
-
-        match (fullchain, privkey) {
-            (Some(fullchain_bytes), Some(privkey_bytes)) => {
-                // Validate PEM format
-                let
fullchain_str = match String::from_utf8(fullchain_bytes.clone()) { - Ok(s) => s, - Err(_) => { - log::warn!("Fullchain for domain {} is not valid UTF-8", domain); - continue; - } - }; - - let privkey_str = match String::from_utf8(privkey_bytes.clone()) { - Ok(s) => s, - Err(_) => { - log::warn!("Private key for domain {} is not valid UTF-8", domain); - continue; - } - }; - - if !fullchain_str.contains("-----BEGIN CERTIFICATE-----") { - log::warn!("Fullchain for domain {} does not appear to be in PEM format", domain); - continue; - } - if !privkey_str.contains("-----BEGIN") { - log::warn!("Private key for domain {} does not appear to be in PEM format", domain); - continue; - } - - // Write certificates to certificate directory - // Use sanitized certificate name for file names (already set above) - // cert_path and key_path are already set using cert_name - - log::debug!("Writing certificate to: {} and key to: {}", cert_path.display(), key_path.display()); - - // Write fullchain to file - // Normalize the fullchain to ensure proper PEM format: - // - Ensure newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE) - // - Ensure file ends with newline - let normalized_fullchain = normalize_pem_chain(&fullchain_str); - let mut cert_file = std::fs::File::create(&cert_path) - .context(format!("Failed to create certificate file for domain: {} at path: {}", domain, cert_path.display()))?; - cert_file.write_all(normalized_fullchain.as_bytes()) - .context(format!("Failed to write certificate file for domain: {} to path: {}", domain, cert_path.display()))?; - cert_file.sync_all() - .context(format!("Failed to sync certificate file for domain: {} at path: {}", domain, cert_path.display()))?; - - // Write private key to file - // Normalize the key to ensure proper PEM format - let normalized_key = normalize_pem_chain(&privkey_str); - let mut key_file = std::fs::File::create(&key_path) - .context(format!("Failed to create key file for domain: {} at path: {}", domain, key_path.display()))?; - key_file.write_all(normalized_key.as_bytes()) - .context(format!("Failed to write key file for domain: {} to path: {}", domain, key_path.display()))?; - key_file.sync_all() - .context(format!("Failed to sync key file for domain: {} at path: {}", domain, key_path.display()))?; - - // Calculate hash from raw bytes (before normalization) to match Redis hash calculation - // Redis calculates hash from: fullchain (raw bytes) + key (raw bytes) - // We need to match this exactly, not from normalized files - use sha2::{Sha256, Digest}; - let mut hasher = Sha256::new(); - hasher.update(&fullchain_bytes); - hasher.update(&privkey_bytes); - let calculated_hash = format!("{:x}", hasher.finalize()); - - downloaded_count += 1; - log::info!("Successfully downloaded and saved certificate for domain: {} to {}", domain, cert_path.display()); - - // Verify files were written correctly - if !cert_path.exists() { - log::warn!("Certificate file does not exist after write: {}", cert_path.display()); - continue; - } - if !key_path.exists() { - log::warn!("Key file does not exist after write: {}", key_path.display()); - continue; - } - - // Store local hash in memory cache after successful download - // Use the hash calculated from raw bytes (matching Redis calculation) - let cert_key = cert_name.to_string(); - let hash_cache = get_certificate_hash_cache(); - let mut cache = hash_cache.write().await; - cache.insert(cert_key, calculated_hash.clone()); - log::debug!("Stored local hash in memory cache for domain: {} -> {} 
(calculated from raw bytes, matching Redis)", domain, calculated_hash);
-
-                // Verify hash matches Redis hash
-                if let Some(remote_hash) = &remote_hash {
-                    if calculated_hash != *remote_hash {
-                        log::warn!("Hash mismatch after download for domain {}: calculated={}, redis={}. This should not happen!",
-                            domain, calculated_hash, remote_hash);
-                    } else {
-                        log::debug!("Hash verified: calculated hash matches Redis hash for domain: {}", domain);
-                    }
-                }
-
-                // Create certificate config entry
-                certificate_configs.push(CertificateConfig {
-                    cert_path: cert_path.to_string_lossy().to_string(),
-                    key_path: key_path.to_string_lossy().to_string(),
-                });
-                log::debug!("Added certificate config for domain: {} -> cert: {}, key: {}",
-                    domain, cert_path.display(), key_path.display());
-            }
-            (Some(_), None) => {
-                missing_count += 1;
-                log::warn!("Certificate fullchain found but private key missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, privkey_key);
-                // Request certificate from ACME using certificate name (with wildcard if configured)
-                if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await {
-                    log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e);
-                } else {
-                    log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name);
-                }
-            }
-            (None, Some(_)) => {
-                missing_count += 1;
-                log::warn!("Certificate private key found but fullchain missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, fullchain_key);
-                // Request certificate from ACME using certificate name (with wildcard if configured)
-                if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await {
-                    log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e);
-                } else {
-                    log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name);
-                }
-            }
-            (None, None) => {
-                missing_count += 1;
-                log::warn!("Certificate not found in Redis for domain: {} (cert: {}, checked keys: fullchain='{}', privkey='{}')",
-                    domain, cert_name, fullchain_key, privkey_key);
-
-                // Try to list matching keys to help debug
-                let pattern = format!("{}:{}:*", prefix, normalized_cert_name);
-                let keys_result: Result<Vec<String>, _> = redis::cmd("KEYS")
-                    .arg(&pattern)
-                    .query_async(&mut connection)
-                    .await;
-                match keys_result {
-                    Ok(keys) => {
-                        if !keys.is_empty() {
-                            log::debug!("Found {} matching keys for pattern '{}': {:?}", keys.len(), pattern, keys);
-                        } else {
-                            log::warn!("No keys found matching pattern '{}'", pattern);
-                        }
-                    }
-                    Err(e) => {
-                        log::debug!("Failed to list keys with pattern '{}': {}", pattern, e);
-                    }
-                }
-
-                // Request certificate from ACME server if enabled
-                // Use certificate name (with wildcard if configured) instead of hostname
-                if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await {
-                    log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e);
-                } else {
-                    log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name);
-                }
-            }
-        }
-    }
-
-    // Log summary
-    if skipped_count > 0 {
-        log::debug!("Skipped {} certificate(s) due to hash matches (using existing files)", skipped_count);
-    }
-    if downloaded_count > 0 {
-        log::info!("Downloaded {} new/updated certificate(s) from Redis", downloaded_count);
-    }
-    if missing_count > 0 {
-        log::warn!("{} certificate(s) not found
in Redis", missing_count); - } - - if !certificate_configs.is_empty() { - log::debug!("Successfully processed {} certificate(s) ({} downloaded, {} skipped)", - certificate_configs.len(), downloaded_count, skipped_count); - log::debug!("Certificate configs to load: {:?}", - certificate_configs.iter().map(|c| format!("cert: {}, key: {}", c.cert_path, c.key_path)).collect::>()); - - // Update the certificate store - // Use "medium" as default TLS grade (can be made configurable) - // Default certificate is None for worker (can be made configurable later) - match Certificates::new(&certificate_configs, "medium", None) { - Some(certificates) => { - let store = get_certificate_store(); - let mut guard = store.write().await; - *guard = Some(Arc::new(certificates)); - log::debug!("Updated certificate store with {} certificates", certificate_configs.len()); - } - None => { - log::error!("Failed to create Certificates object from fetched configs. This usually means one or more certificate files are invalid or cannot be loaded."); - log::error!("Attempted to load {} certificate configs", certificate_configs.len()); - for config in &certificate_configs { - log::error!(" - cert: {}, key: {}", config.cert_path, config.key_path); - } - } - } - } else { - log::warn!("No certificates were processed. Check if certificates exist in Redis for the domains listed in upstreams.yaml, or if all certificates were skipped due to hash matches but files are missing"); - } - - Ok(()) -} - -/// Global ACME config store -static ACME_CONFIG: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); - -/// Global upstreams path store -static UPSTREAMS_PATH: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); - -/// Set the global ACME config (called from main.rs) -/// Can be called multiple times to update the config at runtime -pub fn set_acme_config(config: crate::cli::AcmeConfig) { - let store = ACME_CONFIG.get_or_init(|| Arc::new(std::sync::RwLock::new(None))); - let mut guard = store.write().unwrap(); - let was_development = guard.as_ref().map(|c| c.development).unwrap_or(false); - let is_development = config.development; - - // Extract values for logging before moving config - let enabled = config.enabled; - let port = config.port; - - // Log if development mode changes - if was_development != is_development { - log::warn!("ACME development mode changed: {} -> {}. Existing certificates may have been issued with the previous mode. 
-/// Set the global ACME config (called from main.rs)
-/// Can be called multiple times to update the config at runtime
-pub fn set_acme_config(config: crate::cli::AcmeConfig) {
-    let store = ACME_CONFIG.get_or_init(|| Arc::new(std::sync::RwLock::new(None)));
-    let mut guard = store.write().unwrap();
-    let was_development = guard.as_ref().map(|c| c.development).unwrap_or(false);
-    let is_development = config.development;
-
-    // Extract values for logging before moving config
-    let enabled = config.enabled;
-    let port = config.port;
-
-    // Log if development mode changes
-    if was_development != is_development {
-        log::warn!("ACME development mode changed: {} -> {}. Existing certificates may have been issued with the previous mode. Consider clearing certificates and restarting.", was_development, is_development);
-    }
-
-    *guard = Some(config);
-    log::info!("ACME config updated: enabled={}, development={}, port={}", enabled, is_development, port);
-}
-
-/// Set the global upstreams path (called from certificate worker)
-fn set_upstreams_path(path: String) {
-    let store = UPSTREAMS_PATH.get_or_init(|| Arc::new(std::sync::RwLock::new(None)));
-    let mut guard = store.write().unwrap();
-    *guard = Some(path);
-}
-
-/// Get the global ACME config
-pub async fn get_acme_config() -> Option<crate::cli::AcmeConfig> {
-    let store = ACME_CONFIG.get()?;
-    let guard = tokio::task::spawn_blocking({
-        let store = Arc::clone(store);
-        move || store.read().unwrap().clone()
-    }).await.ok()?;
-    guard
-}
-
-/// Get the global upstreams path
-async fn get_upstreams_path() -> Option<String> {
-    let store = UPSTREAMS_PATH.get()?;
-    let guard = tokio::task::spawn_blocking({
-        let store = Arc::clone(store);
-        move || store.read().unwrap().clone()
-    }).await.ok()?;
-    guard
-}
-
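Both getters route the std::sync::RwLock read through spawn_blocking so that no lock guard is ever held across an .await. Since the guard here is dropped inside a synchronous closure anyway, a plain synchronous read would be an equally safe shape; a sketch of that alternative, for comparison:

// Alternative shape for the getters above: the guard is created and dropped
// entirely within a synchronous body, so it can never cross an .await.
pub fn get_acme_config_sync() -> Option<crate::cli::AcmeConfig> {
    ACME_CONFIG.get().and_then(|store| store.read().unwrap().clone())
}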
-/// Request a certificate from ACME server for a domain
-pub async fn request_certificate_from_acme(
-    domain: &str,
-    normalized_domain: &str,
-    _certificate_path: &str,
-) -> Result<()> {
-    use crate::acme::{Config, ConfigOpts, request_cert};
-    use std::path::PathBuf;
-
-    // Check if ACME is enabled
-    let acme_config = match get_acme_config().await {
-        Some(config) if config.enabled => {
-            // Log which ACME server will be used
-            if config.development {
-                log::warn!("ACME development mode is ENABLED - certificates will be issued from Let's Encrypt STAGING server (not trusted by browsers)");
-            } else {
-                log::info!("ACME development mode is DISABLED - certificates will be issued from Let's Encrypt PRODUCTION server");
-            }
-            config
-        },
-        Some(_) => {
-            log::debug!("ACME is disabled, skipping certificate request for domain: {}", domain);
-            return Ok(());
-        }
-        None => {
-            log::debug!("ACME config not available, skipping certificate request for domain: {}", domain);
-            return Ok(());
-        }
-    };
-
-    // Get email - use from config or default
-    let email = acme_config.email
-        .unwrap_or_else(|| "admin@example.com".to_string());
-
-    // Get Redis URL from RedisManager if available
-    let redis_url = crate::redis::RedisManager::get()
-        .ok()
-        .and_then(|_| {
-            // Use ACME config Redis URL, or try to get from RedisManager
-            acme_config.redis_url.clone()
-                .or_else(|| std::env::var("REDIS_URL").ok())
-        });
-
-    // Read challenge type from upstreams.yaml
-    // Get upstreams path from global store (set by certificate worker) or use default
-    let upstreams_path = get_upstreams_path().await
-        .unwrap_or_else(|| "/root/synapse/upstreams.yaml".to_string());
-
-    // Determine the domain to request from ACME
-    // If normalized_domain is different from domain, it means a certificate name was specified
-    // In that case, we should request the certificate for the certificate domain, not the hostname
-    let (acme_domain, use_dns, domain_email, is_wildcard) = {
-        // Try to read challenge type and certificate config from upstreams.yaml
-        if let Ok(yaml_content) = tokio::fs::read_to_string(&upstreams_path).await {
-            if let Ok(parsed) = serde_yaml::from_str::<crate::utils::structs::Config>(&yaml_content) {
-                if let Some(upstreams) = &parsed.upstreams {
-                    if let Some(host_config) = upstreams.get(domain) {
-                        // Check if a certificate name is specified (different from domain)
-                        let cert_name_opt = host_config.certificate.as_ref();
-                        let acme_wildcard = host_config.acme.as_ref()
-                            .map(|a| a.wildcard)
-                            .unwrap_or(false);
-
-                        // Determine the domain to request from ACME
-                        // If certificate name is specified and wildcard is true, request *.certificate_name
-                        // Otherwise, if certificate name is specified, use it as-is
-                        // If no certificate name, use the hostname domain
-                        let requested_domain = if let Some(cert_name) = cert_name_opt {
-                            // Certificate name is specified - use it for ACME request
-                            if acme_wildcard && !cert_name.starts_with("*.") {
-                                // Wildcard is set in config - request *.certificate_name
-                                format!("*.{}", cert_name)
-                            } else if cert_name.starts_with("*.") {
-                                // Certificate name already has wildcard prefix
-                                cert_name.clone()
-                            } else {
-                                // Certificate name without wildcard
-                                cert_name.clone()
-                            }
-                        } else {
-                            // No certificate name specified - use hostname domain
-                            if acme_wildcard && !domain.starts_with("*.") {
-                                // Wildcard is set but domain doesn't have *. prefix
-                                format!("*.{}", normalized_domain)
-                            } else {
-                                domain.to_string()
-                            }
-                        };
-
-                        // Get challenge type from ACME config in upstreams
-                        let challenge_type = if let Some(acme_cfg) = &host_config.acme {
-                            acme_cfg.challenge_type.clone()
-                        } else {
-                            // Auto-detect: DNS-01 for wildcard, HTTP-01 otherwise
-                            if requested_domain.starts_with("*.") || acme_wildcard {
-                                "dns-01".to_string()
-                            } else {
-                                "http-01".to_string()
-                            }
-                        };
-
-                        let use_dns = challenge_type == "dns-01";
-                        let domain_email = host_config.acme.as_ref()
-                            .and_then(|a| a.email.clone())
-                            .or_else(|| Some(email.clone()));
-
-                        let is_wildcard = requested_domain.starts_with("*.") || acme_wildcard;
-
-                        log::info!("ACME request: hostname={}, certificate={:?}, requested_domain={}, wildcard={}, challenge={}",
-                            domain, cert_name_opt, requested_domain, is_wildcard, challenge_type);
-                        (requested_domain, use_dns, domain_email, is_wildcard)
-                    } else {
-                        // Domain not found in upstreams, auto-detect
-                        // If normalized_domain != domain, it means a certificate name was passed
-                        let is_wildcard = domain.starts_with("*.") || normalized_domain != domain;
-                        let requested_domain = if domain.starts_with("*.") {
-                            domain.to_string()
-                        } else if normalized_domain != domain {
-                            // Certificate name was specified - assume wildcard
-                            format!("*.{}", normalized_domain)
-                        } else {
-                            domain.to_string()
-                        };
-                        let use_dns = is_wildcard; // Use DNS-01 for wildcard certificates
-                        log::info!("Domain {} not found in upstreams.yaml, auto-detecting (requested: {}, wildcard: {}, dns: {})",
-                            domain, requested_domain, is_wildcard, use_dns);
-                        (requested_domain, use_dns, Some(email.clone()), is_wildcard)
-                    }
-                } else {
-                    // No upstreams, auto-detect
-                    let is_wildcard = domain.starts_with("*.") || normalized_domain != domain;
-                    let requested_domain = if domain.starts_with("*.") {
-                        domain.to_string()
-                    } else if normalized_domain != domain {
-                        format!("*.{}", normalized_domain)
-                    } else {
-                        domain.to_string()
-                    };
-                    let use_dns = is_wildcard;
-                    (requested_domain, use_dns, Some(email.clone()), is_wildcard)
-                }
-            } else {
-                // Failed to parse, auto-detect
-                let is_wildcard = domain.starts_with("*.") || normalized_domain != domain;
-                let requested_domain = if domain.starts_with("*.") {
-                    domain.to_string()
-                } else if normalized_domain != domain {
-                    format!("*.{}", normalized_domain)
-                } else {
-                    domain.to_string()
-                };
-                let use_dns = is_wildcard;
-                (requested_domain, use_dns, Some(email.clone()), is_wildcard)
-            }
-        } else {
-            // Failed to read, auto-detect
-            let is_wildcard = domain.starts_with("*.") || normalized_domain != domain;
-            let requested_domain = if domain.starts_with("*.") {
-                domain.to_string()
-            } else if normalized_domain != domain {
- format!("*.{}", normalized_domain) - } else { - domain.to_string() - }; - let use_dns = is_wildcard; - (requested_domain, use_dns, Some(email.clone()), is_wildcard) - } - }; - - // Create domain config for ACME - let mut domain_storage_path = PathBuf::from(&acme_config.storage_path); - domain_storage_path.push(normalized_domain); - - let mut cert_path = domain_storage_path.clone(); - cert_path.push("cert.pem"); - let mut key_path = domain_storage_path.clone(); - key_path.push("key.pem"); - let static_path = domain_storage_path.clone(); - - // Get Redis SSL config if available - let redis_ssl = crate::redis::RedisManager::get() - .ok() - .and_then(|_| { - // Try to get SSL config from global config if available - // For now, we'll use None and let it use defaults - None - }); - - let acme_config_internal = Config { - https_path: domain_storage_path, - cert_path, - key_path, - static_path, - opts: ConfigOpts { - ip: "127.0.0.1".to_string(), - port: acme_config.port, - domain: acme_domain.clone(), - email: domain_email, - https_dns: use_dns, - development: acme_config.development, - dns_lookup_max_attempts: Some(100), - dns_lookup_delay_seconds: Some(10), - storage_type: { - // Always use Redis (storage_type option is kept for compatibility but always uses Redis) - Some("redis".to_string()) - }, - redis_url, - lock_ttl_seconds: Some(900), - redis_ssl, - challenge_max_ttl_seconds: Some(3600), - }, - }; - - // Request certificate from ACME - log::info!("Requesting certificate from ACME: hostname={} -> certificate_domain={} (wildcard: {}, dns: {})", - domain, acme_domain, is_wildcard, use_dns); - - request_cert(&acme_config_internal).await - .context(format!("Failed to request certificate from ACME for hostname: {} (requested domain: {})", domain, acme_domain))?; - - log::info!("Certificate requested successfully from ACME: hostname={}, certificate_domain={}. 
It will be available in Redis after processing.", domain, acme_domain);
-
-    // After requesting, the certificate should be in Redis (if using Redis storage)
-    // The next refresh cycle will pick it up automatically
-
-    Ok(())
-}
-
-/// Check certificates for expiration and renew if expiring within 60 days
-async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<()> {
-    use x509_parser::prelude::*;
-    use x509_parser::nom::Err as NomErr;
-    use rustls_pemfile::read_one;
-    use std::io::BufReader;
-
-    // Get the list of domains from upstreams.yaml
-    let domains = fetch_domains_from_upstreams(upstreams_path).await?;
-
-    if domains.is_empty() {
-        log::debug!("No domains found in upstreams.yaml, skipping expiration check");
-        return Ok(());
-    }
-
-    log::info!("Checking certificate expiration for {} domain(s)", domains.len());
-
-    let redis_manager = RedisManager::get()
-        .context("Redis manager not initialized")?;
-
-    let mut connection = redis_manager.get_connection();
-    let mut renewed_count = 0;
-    let mut checked_count = 0;
-
-    for domain in &domains {
-        let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain);
-
-        // Check if certificate exists in Redis
-        // Get prefix from RedisManager
-        let prefix = RedisManager::get()
-            .map(|rm| rm.get_prefix().to_string())
-            .unwrap_or_else(|_| "ssl-storage".to_string());
-        let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_domain);
-        let fullchain: Option<Vec<u8>> = redis::cmd("GET")
-            .arg(&fullchain_key)
-            .query_async(&mut connection)
-            .await
-            .context(format!("Failed to get fullchain for domain: {}", domain))?;
-
-        let fullchain_bytes = match fullchain {
-            Some(bytes) => bytes,
-            None => {
-                log::debug!("Certificate not found in Redis for domain: {}, skipping expiration check", domain);
-                continue;
-            }
-        };
-
-        // Parse the certificate to get expiration date
-        let fullchain_str = match String::from_utf8(fullchain_bytes.clone()) {
-            Ok(s) => s,
-            Err(_) => {
-                log::warn!("Fullchain for domain {} is not valid UTF-8, skipping expiration check", domain);
-                continue;
-            }
-        };
-
-        // Parse PEM to get the first certificate (domain cert)
-        let mut reader = BufReader::new(fullchain_str.as_bytes());
-        let cert_der = match read_one(&mut reader) {
-            Ok(Some(rustls_pemfile::Item::X509Certificate(cert))) => cert,
-            Ok(_) => {
-                log::warn!("No X509 certificate found in fullchain for domain: {}", domain);
-                continue;
-            }
-            Err(e) => {
-                log::warn!("Failed to parse certificate for domain {}: {:?}", domain, e);
-                continue;
-            }
-        };
-
-        // Parse the DER certificate
-        let (_, x509_cert) = match X509Certificate::from_der(&cert_der) {
-            Ok(cert) => cert,
-            Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => {
-                log::warn!("Failed to parse X509 certificate for domain {}: {:?}", domain, e);
-                continue;
-            }
-            Err(_) => {
-                log::warn!("Unknown error parsing X509 certificate for domain: {}", domain);
-                continue;
-            }
-        };
-
-        // Get expiration date
-        let validity = x509_cert.validity();
-        let not_after_offset = validity.not_after.to_datetime();
-        let now = chrono::Utc::now();
-
-        // Convert OffsetDateTime to chrono::DateTime
-        let not_after = chrono::DateTime::<chrono::Utc>::from_timestamp(
-            not_after_offset.unix_timestamp(),
-            0
-        ).unwrap_or_else(|| {
-            log::warn!("Failed to convert certificate expiration date for domain: {}", domain);
-            now + chrono::Duration::days(90) // Fallback to 90 days from now
-        });
-
-        // Calculate days until expiration
-        let expires_in = not_after - now;
-        let days_until_expiration = expires_in.num_days();
-
-
checked_count += 1; - - log::debug!("Certificate for domain {} expires in {} days (expires at: {})", - domain, days_until_expiration, not_after); - - // Check if certificate expires in less than 60 days - if days_until_expiration < 60 { - log::info!("Certificate for domain {} expires in {} days (< 60 days), starting renewal process", - domain, days_until_expiration); - - // Request renewal from ACME - let certificate_path = "/tmp/synapse-certs"; // Placeholder, will be stored in Redis - if let Err(e) = request_certificate_from_acme(domain, normalized_domain, certificate_path).await { - log::warn!("Failed to renew certificate for domain {}: {}", domain, e); - } else { - log::info!("Successfully initiated certificate renewal for domain: {}", domain); - renewed_count += 1; - } - } else { - log::debug!("Certificate for domain {} is still valid (expires in {} days)", - domain, days_until_expiration); - } - } - - log::info!("Certificate expiration check completed: {} checked, {} renewed", checked_count, renewed_count); - - Ok(()) -} - -/// Clear a specific certificate from both local filesystem and Redis -/// certificate_name: The certificate name (e.g., "kapnative.developnet.hu" or "*.kapnative.developnet.hu") -/// certificate_path: The path where certificates are stored locally -pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) -> Result<()> { - // Normalize certificate name (remove wildcard prefix if present) - let normalized_cert_name = certificate_name.strip_prefix("*.").unwrap_or(certificate_name); - - // Sanitize certificate name for file naming (matches the format used when saving) - let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); - - let cert_dir = std::path::Path::new(certificate_path); - let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name)); - let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name)); - - log::info!("Clearing certificate: {} (normalized: {}, sanitized: {})", - certificate_name, normalized_cert_name, sanitized_cert_name); - - // Delete local certificate files - let mut deleted_local = false; - if cert_path.exists() { - match std::fs::remove_file(&cert_path) { - Ok(_) => { - log::info!("Deleted local certificate file: {}", cert_path.display()); - deleted_local = true; - } - Err(e) => { - log::warn!("Failed to delete local certificate file {}: {}", cert_path.display(), e); - } - } - } else { - log::debug!("Local certificate file does not exist: {}", cert_path.display()); - } - - if key_path.exists() { - match std::fs::remove_file(&key_path) { - Ok(_) => { - log::info!("Deleted local key file: {}", key_path.display()); - deleted_local = true; - } - Err(e) => { - log::warn!("Failed to delete local key file {}: {}", key_path.display(), e); - } - } - } else { - log::debug!("Local key file does not exist: {}", key_path.display()); - } - - // Clear from in-memory hash cache - let hash_cache = get_certificate_hash_cache(); - { - let mut cache = hash_cache.write().await; - cache.remove(certificate_name); - log::debug!("Removed certificate hash from in-memory cache: {}", certificate_name); - } - - // Delete from Redis - let redis_manager = match RedisManager::get() { - Ok(rm) => rm, - Err(e) => { - log::warn!("Redis manager not initialized, skipping Redis deletion: {}", e); - if deleted_local { - log::info!("Certificate cleared from local filesystem only (Redis not available)"); - } - return Ok(()); - } - }; - - let mut connection = redis_manager.get_connection(); - let prefix = 
redis_manager.get_prefix();
-
-    // Delete all certificate-related keys from Redis
-    // This includes all live certificates, metadata, and failure tracking
-    let keys_to_delete = vec![
-        // Live certificate files
-        format!("{}:{}:live:fullchain", prefix, normalized_cert_name),
-        format!("{}:{}:live:privkey", prefix, normalized_cert_name),
-        format!("{}:{}:live:cert", prefix, normalized_cert_name),
-        format!("{}:{}:live:chain", prefix, normalized_cert_name),
-        // Metadata keys
-        format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name),
-        format!("{}:{}:metadata:created_at", prefix, normalized_cert_name),
-        format!("{}:{}:metadata:cert_failure", prefix, normalized_cert_name),
-        format!("{}:{}:metadata:cert_failure_count", prefix, normalized_cert_name),
-    ];
-
-    let mut deleted_redis = 0;
-    for key in &keys_to_delete {
-        match redis::cmd("DEL").arg(key).query_async::<i32>(&mut connection).await {
-            Ok(count) => {
-                if count > 0 {
-                    log::info!("Deleted Redis key: {}", key);
-                    deleted_redis += count;
-                } else {
-                    log::debug!("Redis key does not exist: {}", key);
-                }
-            }
-            Err(e) => {
-                log::warn!("Failed to delete Redis key {}: {}", key, e);
-            }
-        }
-    }
-
-    // Also try to delete any archived certificates, challenge keys, lock keys, etc. (if they exist)
-    let patterns_to_clean = vec![
-        format!("{}:{}:archive:*", prefix, normalized_cert_name),
-        format!("{}:{}:challenge:*", prefix, normalized_cert_name),
-        format!("{}:{}:dns-challenge", prefix, normalized_cert_name),
-        format!("{}:{}:lock", prefix, normalized_cert_name),
-    ];
-
-    for pattern in &patterns_to_clean {
-        match redis::cmd("KEYS").arg(pattern).query_async::<Vec<String>>(&mut connection).await {
-            Ok(keys) => {
-                if !keys.is_empty() {
-                    log::info!("Found {} key(s) matching pattern '{}', deleting...", keys.len(), pattern);
-                    for key in &keys {
-                        match redis::cmd("DEL").arg(key).query_async::<i32>(&mut connection).await {
-                            Ok(count) => {
-                                if count > 0 {
-                                    log::info!("Deleted Redis key: {}", key);
-                                    deleted_redis += count;
-                                }
-                            }
-                            Err(e) => {
-                                log::warn!("Failed to delete Redis key {}: {}", key, e);
-                            }
-                        }
-                    }
-                }
-            }
-            Err(e) => {
-                log::debug!("Failed to list keys matching pattern '{}' (this is OK if none exist): {}", pattern, e);
-            }
-        }
-    }
-
-    if deleted_local || deleted_redis > 0 {
-        log::info!("Successfully cleared certificate '{}': deleted {} local file(s), {} Redis key(s)",
-            certificate_name, if deleted_local { 2 } else { 0 }, deleted_redis);
-    } else {
-        log::warn!("Certificate '{}' not found in local filesystem or Redis", certificate_name);
-    }
-
-    Ok(())
-}
-
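One operational note on the deletion helper: KEYS scans the entire keyspace while blocking Redis, which can matter on shared instances. A hedged sketch of the cursor-based SCAN equivalent, assuming the multiplexed async connection that redis-rs provides (adjust to whatever RedisManager::get_connection actually returns):

// Illustrative SCAN loop; functionally equivalent to the KEYS call above
// but iterates in small batches instead of blocking the server.
async fn scan_keys(
    connection: &mut redis::aio::MultiplexedConnection,
    pattern: &str,
) -> redis::RedisResult<Vec<String>> {
    let mut keys = Vec::new();
    let mut cursor: u64 = 0;
    loop {
        let (next, batch): (u64, Vec<String>) = redis::cmd("SCAN")
            .arg(cursor)
            .arg("MATCH").arg(pattern)
            .arg("COUNT").arg(100)
            .query_async(connection)
            .await?;
        keys.extend(batch);
        cursor = next;
        if cursor == 0 {
            break; // SCAN reports completion with cursor 0
        }
    }
    Ok(keys)
}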
+use anyhow::{Context, Result};
+use std::io::Write;
+use std::sync::Arc;
+use tokio::sync::watch;
+use tokio::time::{interval, Duration};
+
+use crate::redis::RedisManager;
+use crate::utils::tls::{CertificateConfig, Certificates};
+use crate::worker::Worker;
+
+/// Calculate SHA256 hash of certificate files (fullchain + key)
+fn calculate_local_hash(cert_path: &std::path::Path, key_path: &std::path::Path) -> Result<String> {
+    use sha2::{Sha256, Digest};
+    use std::io::Read;
+
+    let mut hasher = Sha256::new();
+
+    // Read and hash certificate file
+    let mut cert_file = std::fs::File::open(cert_path)
+        .context(format!("Failed to open certificate file: {}", cert_path.display()))?;
+    let mut cert_data = Vec::new();
+    cert_file.read_to_end(&mut cert_data)
+        .context(format!("Failed to read certificate file: {}", cert_path.display()))?;
+    hasher.update(&cert_data);
+
+    // Read and hash key file
+    let mut key_file = std::fs::File::open(key_path)
+        .context(format!("Failed to open key file: {}", key_path.display()))?;
+    let mut key_data = Vec::new();
+    key_file.read_to_end(&mut key_data)
+        .context(format!("Failed to read key file: {}", key_path.display()))?;
+    hasher.update(&key_data);
+
+    Ok(format!("{:x}", hasher.finalize()))
+}
+
+/// Normalize PEM certificate chain to ensure proper format
+/// - Ensures newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE)
+/// - Ensures file ends with newline
+fn normalize_pem_chain(chain: &str) -> String {
+    let mut normalized = chain.to_string();
+
+    // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE
+    // Replace "-----END CERTIFICATE----------BEGIN CERTIFICATE-----" with proper newline
+    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----",
+        "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----");
+
+    // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files)
+    normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----",
+        "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----");
+
+    // Ensure file ends with newline
+    if !normalized.ends_with('\n') {
+        normalized.push('\n');
+    }
+
+    normalized
+}
+
+/// Global certificate store for Redis-loaded certificates
+static CERTIFICATE_STORE: once_cell::sync::OnceCell<Arc<tokio::sync::RwLock<Option<Arc<Certificates>>>>> = once_cell::sync::OnceCell::new();
+
+/// Global in-memory cache for certificate hashes (domain -> SHA256 hash)
+/// Using Arc<RwLock<HashMap>> instead of MemoryCache to avoid lifetime issues
+static CERTIFICATE_HASH_CACHE: once_cell::sync::OnceCell<Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>> = once_cell::sync::OnceCell::new();
+
+/// Get the global certificate store
+pub fn get_certificate_store() -> Arc<tokio::sync::RwLock<Option<Arc<Certificates>>>> {
+    CERTIFICATE_STORE.get_or_init(|| Arc::new(tokio::sync::RwLock::new(None))).clone()
+}
+
+/// Get the global certificate hash cache
+/// Cache size: 1000 entries (should be enough for most deployments)
+fn get_certificate_hash_cache() -> Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>> {
+    CERTIFICATE_HASH_CACHE.get_or_init(|| {
+        Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()))
+    }).clone()
+}
+
+/// Certificate worker that fetches SSL certificates from Redis
+/// Uses upstreams.yaml as the source of truth for domains
+pub struct CertificateWorker {
+    certificate_path: String,
+    upstreams_path: String,
+    refresh_interval_secs: u64,
+}
+
+impl CertificateWorker {
+    pub fn new(certificate_path: String, upstreams_path: String, refresh_interval_secs: u64) -> Self {
+        Self {
+            certificate_path,
+            upstreams_path,
+            refresh_interval_secs,
+        }
+    }
+}
+
+impl Worker for CertificateWorker {
+    fn name(&self) -> &str {
+        "certificate"
+    }
+
+    fn run(&self, mut shutdown: watch::Receiver<bool>) -> tokio::task::JoinHandle<()> {
+        let certificate_path = self.certificate_path.clone();
+        let upstreams_path = self.upstreams_path.clone();
+        let refresh_interval_secs = self.refresh_interval_secs;
+        let worker_name = self.name().to_string();
+
+        tokio::spawn(async move {
+            // Store upstreams_path globally for ACME requests
+            set_upstreams_path(upstreams_path.clone());
+
+            // Initial fetch on startup - download all certificates immediately
+            log::info!("[{}] Starting certificate download from Redis on service startup...", worker_name);
+            match fetch_certificates_from_redis(&certificate_path, &upstreams_path).await {
+                Ok(_) => {
+                    log::info!("[{}] Successfully downloaded all certificates from Redis on startup", worker_name);
+                }
+                Err(e) => {
+                    log::warn!("[{}] Failed to fetch certificates from Redis on startup: {}", worker_name, e);
+                    log::warn!("[{}] Will retry on next scheduled interval", worker_name);
+                }
+            }
+
+            // Set up periodic refresh interval
+            let mut refresh_interval = interval(Duration::from_secs(refresh_interval_secs));
+            // Skip the first tick since we already fetched on startup
+            refresh_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
+
+            // Set up periodic expiration check (every 6 hours)
+            let mut expiration_check_interval = interval(Duration::from_secs(6 * 60 * 60)); // 6 hours
+            expiration_check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
+            // Skip the first tick - we'll check after the first certificate fetch
+            let mut first_expiration_check = true;
+
+            loop {
+                tokio::select! {
+                    _ = shutdown.changed() => {
+                        if *shutdown.borrow() {
+                            break;
+                        }
+                    }
+                    _ = refresh_interval.tick() => {
+                        log::debug!("[{}] Periodic certificate refresh triggered", worker_name);
+                        // Update upstreams_path in case it changed
+                        set_upstreams_path(upstreams_path.clone());
+                        if let Err(e) = fetch_certificates_from_redis(&certificate_path, &upstreams_path).await {
+                            log::warn!("[{}] Failed to fetch certificates from Redis: {}", worker_name, e);
+                        }
+                    }
+                    _ = expiration_check_interval.tick() => {
+                        if first_expiration_check {
+                            first_expiration_check = false;
+                            continue; // Skip first check, wait for next interval
+                        }
+                        log::info!("[{}] Periodic certificate expiration check triggered", worker_name);
+                        // Update upstreams_path in case it changed
+                        set_upstreams_path(upstreams_path.clone());
+                        if let Err(e) = check_and_renew_expiring_certificates(&upstreams_path).await {
+                            log::warn!("[{}] Failed to check certificate expiration: {}", worker_name, e);
+                        }
+                    }
+                }
+            }
+
+            log::info!("[{}] Certificate fetcher task stopped", worker_name);
+        })
+    }
+}
+
+/// Start a background task that periodically fetches SSL certificates from Redis
+/// This is kept for backward compatibility
+pub fn start_certificate_fetcher(
+    certificate_path: String,
+    upstreams_path: String,
+    refresh_interval_secs: u64,
+    shutdown: watch::Receiver<bool>,
+) -> tokio::task::JoinHandle<()> {
+    let worker = CertificateWorker::new(certificate_path, upstreams_path, refresh_interval_secs);
+    worker.run(shutdown)
+}
+
+/// Fetch domains from upstreams.yaml file (source of truth)
+async fn fetch_domains_from_upstreams(upstreams_path: &str) -> Result<Vec<String>> {
+    use serde_yaml;
+    use std::path::PathBuf;
+
+    let path = PathBuf::from(upstreams_path);
+
+    // Read and parse upstreams.yaml
+    let yaml_content = tokio::fs::read_to_string(&path)
+        .await
+        .with_context(|| format!("Failed to read upstreams file: {:?}", path))?;
+
+    let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content)
+        .with_context(|| format!("Failed to parse upstreams YAML: {:?}", path))?;
+
+    let mut domains = Vec::new();
+
+    if let Some(upstreams) = &parsed.upstreams {
+        for (hostname, _host_config) in upstreams {
+            domains.push(hostname.clone());
+        }
+    }
+
+    log::info!("Found {} domain(s) in upstreams.yaml: {:?}", domains.len(), domains);
+    Ok(domains)
+}
+
+
+/// Fetch SSL certificates from Redis for domains listed in upstreams.yaml
+async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: &str) -> Result<()> {
+    let redis_manager = RedisManager::get()
+        .context("Redis manager not initialized")?;
+
+    // Parse upstreams.yaml to get domains and their certificate mappings
+    use serde_yaml;
+    use std::path::PathBuf;
+    let path = PathBuf::from(upstreams_path);
+    let yaml_content = tokio::fs::read_to_string(&path)
+        .await
+        .with_context(|| format!("Failed to read upstreams file: {:?}", path))?;
+    let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content)
+        .with_context(|| format!("Failed to parse upstreams YAML: {:?}", path))?;
+
+    // Build mapping of domain -> certificate_name (or None if not specified)
+    // Only include domains that need certificates (have ACME config or ssl_enabled: true)
+    let mut domain_cert_map: Vec<(String, Option<String>)> = Vec::new();
+    if let Some(upstreams) = &parsed.upstreams {
+        for (hostname, host_config) in upstreams {
+            // Only process domains that need certificates
+            if !host_config.needs_certificate() {
+                log::debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", hostname);
+                continue;
+            }
+            let cert_name = host_config.certificate.clone();
+            domain_cert_map.push((hostname.clone(), cert_name));
+        }
+    }
+
+    if domain_cert_map.is_empty() {
+        log::warn!("No domains found in upstreams.yaml, skipping certificate fetch");
+        return Ok(());
+    }
+
+    log::info!("Checking certificates for {} domain(s) from Redis (will skip download if hashes match)", domain_cert_map.len());
+
+    let mut connection = redis_manager.get_connection();
+    let mut certificate_configs = Vec::new();
+    let mut skipped_count = 0;
+    let mut downloaded_count = 0;
+    let mut missing_count = 0;
+    let cert_dir = std::path::Path::new(certificate_path);
+
+    // Create certificate directory if it doesn't exist
+    if !cert_dir.exists() {
+        log::info!("Creating certificate directory: {}", certificate_path);
+        std::fs::create_dir_all(cert_dir)
+            .context(format!("Failed to create certificate directory: {}", certificate_path))?;
+        log::info!("Certificate directory created: {}", certificate_path);
+    } else {
+        log::debug!("Certificate directory already exists: {}", certificate_path);
+    }
+
+    for (domain, cert_name_opt) in &domain_cert_map {
+        // Use certificate name if specified, otherwise use domain name
+        let cert_name = cert_name_opt.as_ref().unwrap_or(domain);
+        // Normalize certificate name (remove wildcard prefix if present)
+        let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(cert_name);
+
+        // Check certificate hash from Redis before downloading
+        // Get prefix from RedisManager
+        let prefix = RedisManager::get()
+            .map(|rm| rm.get_prefix().to_string())
+            .unwrap_or_else(|_| "ssl-storage".to_string());
+        let hash_key = format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name);
+        let remote_hash: Option<String> = redis::cmd("GET")
+            .arg(&hash_key)
+            .query_async(&mut connection)
+            .await
+            .context(format!("Failed to get certificate hash for domain: {}", domain))?;
+
+        // Check in-memory cache first for local hash
+        let hash_cache = get_certificate_hash_cache();
+
+        // Get file paths first - use normalized certificate name for file naming
+        // This matches the filename format used by ACME (which uses normalized domain)
+        // ACME saves as: {sanitized_domain}.crt where domain is normalized (wildcard prefix removed)
+        let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_");
+        let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name));
+        let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name));
+
+        // Check local certificates first before checking Redis
+        // This allows using local certificates even if Redis is unavailable or certificate not in Redis yet
+        let local_hash = if cert_path.exists() && key_path.exists() {
+            // Local files exist - check cache first, then calculate if needed
+            let cached_hash = {
+                let cache = hash_cache.read().await;
+ cache.get(cert_name).cloned() + }; + + if let Some(cached_hash) = cached_hash { + Some(cached_hash) + } else if let Ok(calculated_hash) = calculate_local_hash(&cert_path, &key_path) { + // Calculate from files and cache it + let mut cache = hash_cache.write().await; + cache.insert(cert_name.to_string(), calculated_hash.clone()); + Some(calculated_hash) + } else { + None + } + } else { + // Files don't exist - clear cache + let mut cache = hash_cache.write().await; + cache.remove(cert_name); + None + }; + + // Determine if we need to download + // Priority: Use local certificate if available and valid, otherwise check Redis + let should_download = if let Some(local) = &local_hash { + // Local certificate exists - check if it matches Redis (if Redis hash is available) + if let Some(remote) = &remote_hash { + if remote == local { + // Hashes match - use local certificate + log::debug!("Certificate hash matches for domain: {} (hash: {}), using local certificate", domain, remote); + // Add existing certificate to config without re-downloading + certificate_configs.push(CertificateConfig { + cert_path: cert_path.to_string_lossy().to_string(), + key_path: key_path.to_string_lossy().to_string(), + }); + skipped_count += 1; + log::debug!("Added existing certificate config for domain: {} -> cert: {}, key: {}", + domain, cert_path.display(), key_path.display()); + false // Don't download + } else { + log::debug!("Certificate hash mismatch for domain: {} (remote: {}, local: {}), downloading new certificate from Redis", domain, remote, local); + true // Download - hash changed in Redis + } + } else { + // No Redis hash, but we have local certificate - use it + log::debug!("No Redis hash found for domain: {}, but local certificate exists, using local certificate", domain); + // Add existing certificate to config without re-downloading + certificate_configs.push(CertificateConfig { + cert_path: cert_path.to_string_lossy().to_string(), + key_path: key_path.to_string_lossy().to_string(), + }); + skipped_count += 1; + log::debug!("Added existing local certificate config for domain: {} -> cert: {}, key: {}", + domain, cert_path.display(), key_path.display()); + false // Don't download - use local certificate + } + } else if remote_hash.is_some() { + // No local certificate, but Redis has one - download it + log::debug!("No local certificate found for domain: {}, but Redis hash exists, downloading from Redis", domain); + true // Download from Redis + } else { + // No local certificate and no Redis hash - check if certificate exists in Redis + log::debug!("No local certificate and no Redis hash for domain: {}, will check if certificate exists in Redis", domain); + true // Check Redis for certificate + }; + + // Skip download if not needed + if !should_download { + continue; + } + + // Fetch fullchain and private key from Redis + // Get prefix from RedisManager + let prefix = RedisManager::get() + .map(|rm| rm.get_prefix().to_string()) + .unwrap_or_else(|_| "ssl-storage".to_string()); + // Redis stores certificates with keys: + // - {prefix}:{cert_name}:live:fullchain + // - {prefix}:{cert_name}:live:privkey + let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_cert_name); + let privkey_key = format!("{}:{}:live:privkey", prefix, normalized_cert_name); + + log::info!("Fetching certificate for domain: {} (using cert: {}, normalized: {}, prefix: {})", domain, cert_name, normalized_cert_name, prefix); + log::info!("Fullchain key: '{}', Privkey key: '{}'", fullchain_key, privkey_key); + + let 
+        let fullchain: Option<Vec<u8>> = redis::cmd("GET")
+            .arg(&fullchain_key)
+            .query_async(&mut connection)
+            .await
+            .context(format!("Failed to get fullchain for domain: {}", domain))?;
+
+        let privkey: Option<Vec<u8>> = redis::cmd("GET")
+            .arg(&privkey_key)
+            .query_async(&mut connection)
+            .await
+            .context(format!("Failed to get private key for domain: {}", domain))?;
+
+        log::info!("Redis GET results for domain {}: fullchain={}, privkey={}",
+            domain,
+            if fullchain.is_some() { "Some" } else { "None" },
+            if privkey.is_some() { "Some" } else { "None" }
+        );
+
+        match (fullchain, privkey) {
+            (Some(fullchain_bytes), Some(privkey_bytes)) => {
+                // Validate PEM format
+                let fullchain_str = match String::from_utf8(fullchain_bytes.clone()) {
+                    Ok(s) => s,
+                    Err(_) => {
+                        log::warn!("Fullchain for domain {} is not valid UTF-8", domain);
+                        continue;
+                    }
+                };
+
+                let privkey_str = match String::from_utf8(privkey_bytes.clone()) {
+                    Ok(s) => s,
+                    Err(_) => {
+                        log::warn!("Private key for domain {} is not valid UTF-8", domain);
+                        continue;
+                    }
+                };
+
+                if !fullchain_str.contains("-----BEGIN CERTIFICATE-----") {
+                    log::warn!("Fullchain for domain {} does not appear to be in PEM format", domain);
+                    continue;
+                }
+                if !privkey_str.contains("-----BEGIN") {
+                    log::warn!("Private key for domain {} does not appear to be in PEM format", domain);
+                    continue;
+                }
+
+                // Write certificates to certificate directory
+                // Use sanitized certificate name for file names (already set above)
+                // cert_path and key_path are already set using cert_name
+
+                log::debug!("Writing certificate to: {} and key to: {}", cert_path.display(), key_path.display());
+
+                // Write fullchain to file
+                // Normalize the fullchain to ensure proper PEM format:
+                // - Ensure newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE)
+                // - Ensure file ends with newline
+                let normalized_fullchain = normalize_pem_chain(&fullchain_str);
+                let mut cert_file = std::fs::File::create(&cert_path)
+                    .context(format!("Failed to create certificate file for domain: {} at path: {}", domain, cert_path.display()))?;
+                cert_file.write_all(normalized_fullchain.as_bytes())
+                    .context(format!("Failed to write certificate file for domain: {} to path: {}", domain, cert_path.display()))?;
+                cert_file.sync_all()
+                    .context(format!("Failed to sync certificate file for domain: {} at path: {}", domain, cert_path.display()))?;
+
+                // Write private key to file
+                // Normalize the key to ensure proper PEM format
+                let normalized_key = normalize_pem_chain(&privkey_str);
+                let mut key_file = std::fs::File::create(&key_path)
+                    .context(format!("Failed to create key file for domain: {} at path: {}", domain, key_path.display()))?;
+                key_file.write_all(normalized_key.as_bytes())
+                    .context(format!("Failed to write key file for domain: {} to path: {}", domain, key_path.display()))?;
+                key_file.sync_all()
+                    .context(format!("Failed to sync key file for domain: {} at path: {}", domain, key_path.display()))?;
+
+                // Calculate hash from raw bytes (before normalization) to match Redis hash calculation
+                // Redis calculates hash from: fullchain (raw bytes) + key (raw bytes)
+                // We need to match this exactly, not from normalized files
+                use sha2::{Sha256, Digest};
+                let mut hasher = Sha256::new();
+                hasher.update(&fullchain_bytes);
+                hasher.update(&privkey_bytes);
+                let calculated_hash = format!("{:x}", hasher.finalize());
+
+                downloaded_count += 1;
+                log::info!("Successfully downloaded and saved certificate for domain: {} to {}", domain, cert_path.display());
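`normalize_pem_chain`, called above for both the chain and the key, is defined elsewhere in the crate; per the comments it guarantees a newline between back-to-back PEM blocks and a trailing newline. A rough sketch covering only those two documented behaviors (the real helper may handle more cases, e.g. CRLF input):

```rust
// Hypothetical sketch of normalize_pem_chain, limited to the two fixes
// the comments above describe.
fn normalize_pem_chain(pem: &str) -> String {
    let mut out = pem.replace(
        "-----END CERTIFICATE----------BEGIN CERTIFICATE-----",
        "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----",
    );
    if !out.ends_with('\n') {
        out.push('\n');
    }
    out
}
```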
+                // Verify files were written correctly
+                if !cert_path.exists() {
+                    log::warn!("Certificate file does not exist after write: {}", cert_path.display());
+                    continue;
+                }
+                if !key_path.exists() {
+                    log::warn!("Key file does not exist after write: {}", key_path.display());
+                    continue;
+                }
+
+                // Store local hash in memory cache after successful download
+                // Use the hash calculated from raw bytes (matching Redis calculation)
+                let cert_key = cert_name.to_string();
+                let hash_cache = get_certificate_hash_cache();
+                let mut cache = hash_cache.write().await;
+                cache.insert(cert_key, calculated_hash.clone());
+                log::debug!("Stored local hash in memory cache for domain: {} -> {} (calculated from raw bytes, matching Redis)", domain, calculated_hash);
+
+                // Verify hash matches Redis hash
+                if let Some(remote_hash) = &remote_hash {
+                    if calculated_hash != *remote_hash {
+                        log::warn!("Hash mismatch after download for domain {}: calculated={}, redis={}. This should not happen!",
+                            domain, calculated_hash, remote_hash);
+                    } else {
+                        log::debug!("Hash verified: calculated hash matches Redis hash for domain: {}", domain);
+                    }
+                }
+
+                // Create certificate config entry
+                certificate_configs.push(CertificateConfig {
+                    cert_path: cert_path.to_string_lossy().to_string(),
+                    key_path: key_path.to_string_lossy().to_string(),
+                });
+                log::debug!("Added certificate config for domain: {} -> cert: {}, key: {}",
+                    domain, cert_path.display(), key_path.display());
+            }
+            (Some(_), None) => {
+                missing_count += 1;
+                log::warn!("Certificate fullchain found but private key missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, privkey_key);
+                // Request certificate from ACME using certificate name (with wildcard if configured)
+                if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await {
+                    log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e);
+                } else {
+                    log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name);
+                }
+            }
+            (None, Some(_)) => {
+                missing_count += 1;
+                log::warn!("Certificate private key found but fullchain missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, fullchain_key);
+                // Request certificate from ACME using certificate name (with wildcard if configured)
+                if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await {
+                    log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e);
+                } else {
+                    log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name);
+                }
+            }
+            (None, None) => {
+                missing_count += 1;
+                log::warn!("Certificate not found in Redis for domain: {} (cert: {}, checked keys: fullchain='{}', privkey='{}')",
+                    domain, cert_name, fullchain_key, privkey_key);
+
+                // Try to list matching keys to help debug
+                let pattern = format!("{}:{}:*", prefix, normalized_cert_name);
+                let keys_result: Result<Vec<String>, _> = redis::cmd("KEYS")
+                    .arg(&pattern)
+                    .query_async(&mut connection)
+                    .await;
+                match keys_result {
+                    Ok(keys) => {
+                        if !keys.is_empty() {
+                            log::debug!("Found {} matching keys for pattern '{}': {:?}", keys.len(), pattern, keys);
+                        } else {
+                            log::warn!("No keys found matching pattern '{}'", pattern);
+                        }
+                    }
+                    Err(e) => {
+                        log::debug!("Failed to list keys with pattern '{}': {}", pattern, e);
+                    }
+                }
+
+                // Request certificate from ACME server if enabled
+                // Use certificate name (with wildcard if configured) instead of hostname
+                if let
Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await { + log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + } else { + log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name); + } + } + } + } + + // Log summary + if skipped_count > 0 { + log::debug!("Skipped {} certificate(s) due to hash matches (using existing files)", skipped_count); + } + if downloaded_count > 0 { + log::info!("Downloaded {} new/updated certificate(s) from Redis", downloaded_count); + } + if missing_count > 0 { + log::warn!("{} certificate(s) not found in Redis", missing_count); + } + + if !certificate_configs.is_empty() { + log::debug!("Successfully processed {} certificate(s) ({} downloaded, {} skipped)", + certificate_configs.len(), downloaded_count, skipped_count); + log::debug!("Certificate configs to load: {:?}", + certificate_configs.iter().map(|c| format!("cert: {}, key: {}", c.cert_path, c.key_path)).collect::>()); + + // Update the certificate store + // Use "medium" as default TLS grade (can be made configurable) + // Default certificate is None for worker (can be made configurable later) + match Certificates::new(&certificate_configs, "medium", None) { + Some(certificates) => { + let store = get_certificate_store(); + let mut guard = store.write().await; + *guard = Some(Arc::new(certificates)); + log::debug!("Updated certificate store with {} certificates", certificate_configs.len()); + } + None => { + log::error!("Failed to create Certificates object from fetched configs. This usually means one or more certificate files are invalid or cannot be loaded."); + log::error!("Attempted to load {} certificate configs", certificate_configs.len()); + for config in &certificate_configs { + log::error!(" - cert: {}, key: {}", config.cert_path, config.key_path); + } + } + } + } else { + log::warn!("No certificates were processed. Check if certificates exist in Redis for the domains listed in upstreams.yaml, or if all certificates were skipped due to hash matches but files are missing"); + } + + Ok(()) +} + +/// Global ACME config store +static ACME_CONFIG: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); + +/// Global upstreams path store +static UPSTREAMS_PATH: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); + +/// Set the global ACME config (called from main.rs) +/// Can be called multiple times to update the config at runtime +pub fn set_acme_config(config: crate::cli::AcmeConfig) { + let store = ACME_CONFIG.get_or_init(|| Arc::new(std::sync::RwLock::new(None))); + let mut guard = store.write().unwrap(); + let was_development = guard.as_ref().map(|c| c.development).unwrap_or(false); + let is_development = config.development; + + // Extract values for logging before moving config + let enabled = config.enabled; + let port = config.port; + + // Log if development mode changes + if was_development != is_development { + log::warn!("ACME development mode changed: {} -> {}. Existing certificates may have been issued with the previous mode. 
Consider clearing certificates and restarting.", was_development, is_development); + } + + *guard = Some(config); + log::info!("ACME config updated: enabled={}, development={}, port={}", enabled, is_development, port); +} + +/// Set the global upstreams path (called from certificate worker) +fn set_upstreams_path(path: String) { + let store = UPSTREAMS_PATH.get_or_init(|| Arc::new(std::sync::RwLock::new(None))); + let mut guard = store.write().unwrap(); + *guard = Some(path); +} + +/// Get the global ACME config +pub async fn get_acme_config() -> Option { + let store = ACME_CONFIG.get()?; + let guard = tokio::task::spawn_blocking({ + let store = Arc::clone(store); + move || store.read().unwrap().clone() + }).await.ok()?; + guard +} + +/// Get the global upstreams path +async fn get_upstreams_path() -> Option { + let store = UPSTREAMS_PATH.get()?; + let guard = tokio::task::spawn_blocking({ + let store = Arc::clone(store); + move || store.read().unwrap().clone() + }).await.ok()?; + guard +} + +/// Request a certificate from ACME server for a domain +pub async fn request_certificate_from_acme( + domain: &str, + normalized_domain: &str, + _certificate_path: &str, +) -> Result<()> { + use crate::acme::{Config, ConfigOpts, request_cert}; + use std::path::PathBuf; + + // Check if ACME is enabled + let acme_config = match get_acme_config().await { + Some(config) if config.enabled => { + // Log which ACME server will be used + if config.development { + log::warn!("ACME development mode is ENABLED - certificates will be issued from Let's Encrypt STAGING server (not trusted by browsers)"); + } else { + log::info!("ACME development mode is DISABLED - certificates will be issued from Let's Encrypt PRODUCTION server"); + } + config + }, + Some(_) => { + log::debug!("ACME is disabled, skipping certificate request for domain: {}", domain); + return Ok(()); + } + None => { + log::debug!("ACME config not available, skipping certificate request for domain: {}", domain); + return Ok(()); + } + }; + + // Get email - use from config or default + let email = acme_config.email + .unwrap_or_else(|| "admin@example.com".to_string()); + + // Get Redis URL from RedisManager if available + let redis_url = crate::redis::RedisManager::get() + .ok() + .and_then(|_| { + // Use ACME config Redis URL, or try to get from RedisManager + acme_config.redis_url.clone() + .or_else(|| std::env::var("REDIS_URL").ok()) + }); + + // Read challenge type from upstreams.yaml + // Get upstreams path from global store (set by certificate worker) or use default + let upstreams_path = get_upstreams_path().await + .unwrap_or_else(|| "/root/synapse/upstreams.yaml".to_string()); + + // Determine the domain to request from ACME + // If normalized_domain is different from domain, it means a certificate name was specified + // In that case, we should request the certificate for the certificate domain, not the hostname + let (acme_domain, use_dns, domain_email, is_wildcard) = { + // Try to read challenge type and certificate config from upstreams.yaml + if let Ok(yaml_content) = tokio::fs::read_to_string(&upstreams_path).await { + if let Ok(parsed) = serde_yaml::from_str::(&yaml_content) { + if let Some(upstreams) = &parsed.upstreams { + if let Some(host_config) = upstreams.get(domain) { + // Check if a certificate name is specified (different from domain) + let cert_name_opt = host_config.certificate.as_ref(); + let acme_wildcard = host_config.acme.as_ref() + .map(|a| a.wildcard) + .unwrap_or(false); + + // Determine the domain to request from 
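`ACME_CONFIG` and `UPSTREAMS_PATH` above share one pattern: a `OnceCell` holding `Arc<RwLock<Option<T>>>`, so the value can be set late and replaced at runtime, with `spawn_blocking` keeping the std lock off the async executor on reads. A stripped-down sketch of that shape, with hypothetical names and `String` standing in for the stored type:

```rust
use once_cell::sync::OnceCell;
use std::sync::{Arc, RwLock};

// Late-initialized, runtime-replaceable global slot, as used above.
static SLOT: OnceCell<Arc<RwLock<Option<String>>>> = OnceCell::new();

fn set_slot(value: String) {
    let store = SLOT.get_or_init(|| Arc::new(RwLock::new(None)));
    *store.write().unwrap() = Some(value);
}

async fn get_slot() -> Option<String> {
    let store = Arc::clone(SLOT.get()?);
    // Read the std lock on a blocking thread, as the getters above do.
    tokio::task::spawn_blocking(move || store.read().unwrap().clone())
        .await
        .ok()?
}
```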
ACME + // If certificate name is specified and wildcard is true, request *.certificate_name + // Otherwise, if certificate name is specified, use it as-is + // If no certificate name, use the hostname domain + let requested_domain = if let Some(cert_name) = cert_name_opt { + // Certificate name is specified - use it for ACME request + if acme_wildcard && !cert_name.starts_with("*.") { + // Wildcard is set in config - request *.certificate_name + format!("*.{}", cert_name) + } else if cert_name.starts_with("*.") { + // Certificate name already has wildcard prefix + cert_name.clone() + } else { + // Certificate name without wildcard + cert_name.clone() + } + } else { + // No certificate name specified - use hostname domain + if acme_wildcard && !domain.starts_with("*.") { + // Wildcard is set but domain doesn't have *. prefix + format!("*.{}", normalized_domain) + } else { + domain.to_string() + } + }; + + // Get challenge type from ACME config in upstreams + let challenge_type = if let Some(acme_cfg) = &host_config.acme { + acme_cfg.challenge_type.clone() + } else { + // Auto-detect: DNS-01 for wildcard, HTTP-01 otherwise + if requested_domain.starts_with("*.") || acme_wildcard { + "dns-01".to_string() + } else { + "http-01".to_string() + } + }; + + let use_dns = challenge_type == "dns-01"; + let domain_email = host_config.acme.as_ref() + .and_then(|a| a.email.clone()) + .or_else(|| Some(email.clone())); + + let is_wildcard = requested_domain.starts_with("*.") || acme_wildcard; + + log::info!("ACME request: hostname={}, certificate={:?}, requested_domain={}, wildcard={}, challenge={}", + domain, cert_name_opt, requested_domain, is_wildcard, challenge_type); + (requested_domain, use_dns, domain_email, is_wildcard) + } else { + // Domain not found in upstreams, auto-detect + // If normalized_domain != domain, it means a certificate name was passed + let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let requested_domain = if domain.starts_with("*.") { + domain.to_string() + } else if normalized_domain != domain { + // Certificate name was specified - assume wildcard + format!("*.{}", normalized_domain) + } else { + domain.to_string() + }; + let use_dns = is_wildcard; // Use DNS-01 for wildcard certificates + log::info!("Domain {} not found in upstreams.yaml, auto-detecting (requested: {}, wildcard: {}, dns: {})", + domain, requested_domain, is_wildcard, use_dns); + (requested_domain, use_dns, Some(email.clone()), is_wildcard) + } + } else { + // No upstreams, auto-detect + let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let requested_domain = if domain.starts_with("*.") { + domain.to_string() + } else if normalized_domain != domain { + format!("*.{}", normalized_domain) + } else { + domain.to_string() + }; + let use_dns = is_wildcard; + (requested_domain, use_dns, Some(email.clone()), is_wildcard) + } + } else { + // Failed to parse, auto-detect + let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let requested_domain = if domain.starts_with("*.") { + domain.to_string() + } else if normalized_domain != domain { + format!("*.{}", normalized_domain) + } else { + domain.to_string() + }; + let use_dns = is_wildcard; + (requested_domain, use_dns, Some(email.clone()), is_wildcard) + } + } else { + // Failed to read, auto-detect + let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let requested_domain = if domain.starts_with("*.") { + domain.to_string() + } else if normalized_domain != domain { 
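The requested-domain ladder in this function (including its repeated auto-detect fallbacks) condenses to a small pure function. A hypothetical helper mirroring the documented precedence: an explicit certificate name wins, `wildcard: true` prepends a single `*.`, otherwise the hostname is requested as-is; DNS-01 is then chosen exactly when the result is a wildcard.

```rust
// Hypothetical condensation of the requested-domain logic in this function.
fn resolve_acme_domain(hostname: &str, cert_name: Option<&str>, wildcard: bool) -> String {
    match cert_name {
        Some(name) if name.starts_with("*.") => name.to_string(), // already wildcard
        Some(name) if wildcard => format!("*.{}", name),          // wildcard: true in config
        Some(name) => name.to_string(),
        None if wildcard && !hostname.starts_with("*.") => {
            // the patch uses the normalized (prefix-stripped) hostname here
            format!("*.{}", hostname.trim_start_matches("*."))
        }
        None => hostname.to_string(),
    }
}
```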
+ format!("*.{}", normalized_domain) + } else { + domain.to_string() + }; + let use_dns = is_wildcard; + (requested_domain, use_dns, Some(email.clone()), is_wildcard) + } + }; + + // Create domain config for ACME + let mut domain_storage_path = PathBuf::from(&acme_config.storage_path); + domain_storage_path.push(normalized_domain); + + let mut cert_path = domain_storage_path.clone(); + cert_path.push("cert.pem"); + let mut key_path = domain_storage_path.clone(); + key_path.push("key.pem"); + let static_path = domain_storage_path.clone(); + + // Get Redis SSL config if available + let redis_ssl = crate::redis::RedisManager::get() + .ok() + .and_then(|_| { + // Try to get SSL config from global config if available + // For now, we'll use None and let it use defaults + None + }); + + let acme_config_internal = Config { + https_path: domain_storage_path, + cert_path, + key_path, + static_path, + opts: ConfigOpts { + ip: "127.0.0.1".to_string(), + port: acme_config.port, + domain: acme_domain.clone(), + email: domain_email, + https_dns: use_dns, + development: acme_config.development, + dns_lookup_max_attempts: Some(100), + dns_lookup_delay_seconds: Some(10), + storage_type: { + // Always use Redis (storage_type option is kept for compatibility but always uses Redis) + Some("redis".to_string()) + }, + redis_url, + lock_ttl_seconds: Some(900), + redis_ssl, + challenge_max_ttl_seconds: Some(3600), + }, + }; + + // Request certificate from ACME + log::info!("Requesting certificate from ACME: hostname={} -> certificate_domain={} (wildcard: {}, dns: {})", + domain, acme_domain, is_wildcard, use_dns); + + request_cert(&acme_config_internal).await + .context(format!("Failed to request certificate from ACME for hostname: {} (requested domain: {})", domain, acme_domain))?; + + log::info!("Certificate requested successfully from ACME: hostname={}, certificate_domain={}. 
It will be available in Redis after processing.", domain, acme_domain);
+
+    // After requesting, the certificate should be in Redis (if using Redis storage)
+    // The next refresh cycle will pick it up automatically
+
+    Ok(())
+}
+
+/// Check certificates for expiration and renew if expiring within 60 days
+async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<()> {
+    use x509_parser::prelude::*;
+    use x509_parser::nom::Err as NomErr;
+    use rustls_pemfile::read_one;
+    use std::io::BufReader;
+
+    // Get the list of domains from upstreams.yaml
+    let domains = fetch_domains_from_upstreams(upstreams_path).await?;
+
+    if domains.is_empty() {
+        log::debug!("No domains found in upstreams.yaml, skipping expiration check");
+        return Ok(());
+    }
+
+    log::info!("Checking certificate expiration for {} domain(s)", domains.len());
+
+    let redis_manager = RedisManager::get()
+        .context("Redis manager not initialized")?;
+
+    let mut connection = redis_manager.get_connection();
+    let mut renewed_count = 0;
+    let mut checked_count = 0;
+
+    for domain in &domains {
+        let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain);
+
+        // Check if certificate exists in Redis
+        // Get prefix from RedisManager
+        let prefix = RedisManager::get()
+            .map(|rm| rm.get_prefix().to_string())
+            .unwrap_or_else(|_| "ssl-storage".to_string());
+        let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_domain);
+        let fullchain: Option<Vec<u8>> = redis::cmd("GET")
+            .arg(&fullchain_key)
+            .query_async(&mut connection)
+            .await
+            .context(format!("Failed to get fullchain for domain: {}", domain))?;
+
+        let fullchain_bytes = match fullchain {
+            Some(bytes) => bytes,
+            None => {
+                log::debug!("Certificate not found in Redis for domain: {}, skipping expiration check", domain);
+                continue;
+            }
+        };
+
+        // Parse the certificate to get expiration date
+        let fullchain_str = match String::from_utf8(fullchain_bytes.clone()) {
+            Ok(s) => s,
+            Err(_) => {
+                log::warn!("Fullchain for domain {} is not valid UTF-8, skipping expiration check", domain);
+                continue;
+            }
+        };
+
+        // Parse PEM to get the first certificate (domain cert)
+        let mut reader = BufReader::new(fullchain_str.as_bytes());
+        let cert_der = match read_one(&mut reader) {
+            Ok(Some(rustls_pemfile::Item::X509Certificate(cert))) => cert,
+            Ok(_) => {
+                log::warn!("No X509 certificate found in fullchain for domain: {}", domain);
+                continue;
+            }
+            Err(e) => {
+                log::warn!("Failed to parse certificate for domain {}: {:?}", domain, e);
+                continue;
+            }
+        };
+
+        // Parse the DER certificate
+        let (_, x509_cert) = match X509Certificate::from_der(&cert_der) {
+            Ok(cert) => cert,
+            Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => {
+                log::warn!("Failed to parse X509 certificate for domain {}: {:?}", domain, e);
+                continue;
+            }
+            Err(_) => {
+                log::warn!("Unknown error parsing X509 certificate for domain: {}", domain);
+                continue;
+            }
+        };
+
+        // Get expiration date
+        let validity = x509_cert.validity();
+        let not_after_offset = validity.not_after.to_datetime();
+        let now = chrono::Utc::now();
+
+        // Convert OffsetDateTime to chrono::DateTime
+        let not_after = chrono::DateTime::<chrono::Utc>::from_timestamp(
+            not_after_offset.unix_timestamp(),
+            0
+        ).unwrap_or_else(|| {
+            log::warn!("Failed to convert certificate expiration date for domain: {}", domain);
+            now + chrono::Duration::days(90) // Fallback to 90 days from now
+        });
+
+        // Calculate days until expiration
+        let expires_in = not_after - now;
+        let days_until_expiration = expires_in.num_days();
+
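For reference, the PEM-to-days pipeline above extracted into one function, using the same crates (`rustls-pemfile` 1.x `read_one`, `x509-parser`) but dropping the per-branch logging; a simplified sketch, not a drop-in replacement:

```rust
use std::io::BufReader;

// Hypothetical condensation of the expiry computation above.
fn days_until_expiry(fullchain_pem: &str) -> Option<i64> {
    use x509_parser::prelude::*;

    let mut reader = BufReader::new(fullchain_pem.as_bytes());
    let der = match rustls_pemfile::read_one(&mut reader).ok()?? {
        rustls_pemfile::Item::X509Certificate(der) => der,
        _ => return None,
    };
    let (_, cert) = X509Certificate::from_der(&der).ok()?;
    let not_after = cert.validity().not_after.to_datetime().unix_timestamp();
    Some((not_after - chrono::Utc::now().timestamp()) / 86_400)
}
```

The loop that follows renews whenever this value drops under 60.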
checked_count += 1; + + log::debug!("Certificate for domain {} expires in {} days (expires at: {})", + domain, days_until_expiration, not_after); + + // Check if certificate expires in less than 60 days + if days_until_expiration < 60 { + log::info!("Certificate for domain {} expires in {} days (< 60 days), starting renewal process", + domain, days_until_expiration); + + // Request renewal from ACME + let certificate_path = "/tmp/synapse-certs"; // Placeholder, will be stored in Redis + if let Err(e) = request_certificate_from_acme(domain, normalized_domain, certificate_path).await { + log::warn!("Failed to renew certificate for domain {}: {}", domain, e); + } else { + log::info!("Successfully initiated certificate renewal for domain: {}", domain); + renewed_count += 1; + } + } else { + log::debug!("Certificate for domain {} is still valid (expires in {} days)", + domain, days_until_expiration); + } + } + + log::info!("Certificate expiration check completed: {} checked, {} renewed", checked_count, renewed_count); + + Ok(()) +} + +/// Clear a specific certificate from both local filesystem and Redis +/// certificate_name: The certificate name (e.g., "kapnative.developnet.hu" or "*.kapnative.developnet.hu") +/// certificate_path: The path where certificates are stored locally +pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) -> Result<()> { + // Normalize certificate name (remove wildcard prefix if present) + let normalized_cert_name = certificate_name.strip_prefix("*.").unwrap_or(certificate_name); + + // Sanitize certificate name for file naming (matches the format used when saving) + let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); + + let cert_dir = std::path::Path::new(certificate_path); + let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name)); + let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name)); + + log::info!("Clearing certificate: {} (normalized: {}, sanitized: {})", + certificate_name, normalized_cert_name, sanitized_cert_name); + + // Delete local certificate files + let mut deleted_local = false; + if cert_path.exists() { + match std::fs::remove_file(&cert_path) { + Ok(_) => { + log::info!("Deleted local certificate file: {}", cert_path.display()); + deleted_local = true; + } + Err(e) => { + log::warn!("Failed to delete local certificate file {}: {}", cert_path.display(), e); + } + } + } else { + log::debug!("Local certificate file does not exist: {}", cert_path.display()); + } + + if key_path.exists() { + match std::fs::remove_file(&key_path) { + Ok(_) => { + log::info!("Deleted local key file: {}", key_path.display()); + deleted_local = true; + } + Err(e) => { + log::warn!("Failed to delete local key file {}: {}", key_path.display(), e); + } + } + } else { + log::debug!("Local key file does not exist: {}", key_path.display()); + } + + // Clear from in-memory hash cache + let hash_cache = get_certificate_hash_cache(); + { + let mut cache = hash_cache.write().await; + cache.remove(certificate_name); + log::debug!("Removed certificate hash from in-memory cache: {}", certificate_name); + } + + // Delete from Redis + let redis_manager = match RedisManager::get() { + Ok(rm) => rm, + Err(e) => { + log::warn!("Redis manager not initialized, skipping Redis deletion: {}", e); + if deleted_local { + log::info!("Certificate cleared from local filesystem only (Redis not available)"); + } + return Ok(()); + } + }; + + let mut connection = redis_manager.get_connection(); + let prefix = 
redis_manager.get_prefix(); + + // Delete all certificate-related keys from Redis + // This includes all live certificates, metadata, and failure tracking + let keys_to_delete = vec![ + // Live certificate files + format!("{}:{}:live:fullchain", prefix, normalized_cert_name), + format!("{}:{}:live:privkey", prefix, normalized_cert_name), + format!("{}:{}:live:cert", prefix, normalized_cert_name), + format!("{}:{}:live:chain", prefix, normalized_cert_name), + // Metadata keys + format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name), + format!("{}:{}:metadata:created_at", prefix, normalized_cert_name), + format!("{}:{}:metadata:cert_failure", prefix, normalized_cert_name), + format!("{}:{}:metadata:cert_failure_count", prefix, normalized_cert_name), + ]; + + let mut deleted_redis = 0; + for key in &keys_to_delete { + match redis::cmd("DEL").arg(key).query_async::(&mut connection).await { + Ok(count) => { + if count > 0 { + log::info!("Deleted Redis key: {}", key); + deleted_redis += count; + } else { + log::debug!("Redis key does not exist: {}", key); + } + } + Err(e) => { + log::warn!("Failed to delete Redis key {}: {}", key, e); + } + } + } + + // Also try to delete any archived certificates, challenge keys, lock keys, etc. (if they exist) + let patterns_to_clean = vec![ + format!("{}:{}:archive:*", prefix, normalized_cert_name), + format!("{}:{}:challenge:*", prefix, normalized_cert_name), + format!("{}:{}:dns-challenge", prefix, normalized_cert_name), + format!("{}:{}:lock", prefix, normalized_cert_name), + ]; + + for pattern in &patterns_to_clean { + match redis::cmd("KEYS").arg(pattern).query_async::>(&mut connection).await { + Ok(keys) => { + if !keys.is_empty() { + log::info!("Found {} key(s) matching pattern '{}', deleting...", keys.len(), pattern); + for key in &keys { + match redis::cmd("DEL").arg(key).query_async::(&mut connection).await { + Ok(count) => { + if count > 0 { + log::info!("Deleted Redis key: {}", key); + deleted_redis += count; + } + } + Err(e) => { + log::warn!("Failed to delete Redis key {}: {}", key, e); + } + } + } + } + } + Err(e) => { + log::debug!("Failed to list keys matching pattern '{}' (this is OK if none exist): {}", pattern, e); + } + } + } + + if deleted_local || deleted_redis > 0 { + log::info!("Successfully cleared certificate '{}': deleted {} local file(s), {} Redis key(s)", + certificate_name, if deleted_local { 2 } else { 0 }, deleted_redis); + } else { + log::warn!("Certificate '{}' not found in local filesystem or Redis", certificate_name); + } + + Ok(()) +} + diff --git a/src/worker/config.rs b/src/worker/config.rs index 2b18d85..3e41360 100644 --- a/src/worker/config.rs +++ b/src/worker/config.rs @@ -1,548 +1,548 @@ -use hyper::StatusCode; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::io::Read; -use flate2::read::GzDecoder; -use std::sync::{Arc, OnceLock, RwLock}; -use tokio::sync::watch; -use tokio::time::{interval, Duration, MissedTickBehavior}; -use crate::content_scanning::ContentScanningConfig; -use crate::http_client::get_global_reqwest_client; -use crate::worker::Worker; - -pub type Details = serde_json::Value; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ConfigApiResponse { - pub success: bool, - pub config: Config, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Config { - pub access_rules: AccessRule, - pub waf_rules: WafRules, - #[serde(default)] - pub content_scanning: ContentScanningConfig, - pub created_at: String, - pub updated_at: String, - 
pub last_modified: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct AccessRule { - pub id: String, - pub name: String, - pub description: String, - pub allow: RuleSet, - pub block: RuleSet, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct WafRules { - pub rules: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct WafRule { - pub id: String, - pub name: String, - pub org_id: String, - pub description: String, - pub action: String, - pub expression: String, - #[serde(default)] - pub config: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct RateLimitConfig { - pub period: String, - pub duration: String, - pub requests: String, -} - -impl RateLimitConfig { - pub fn from_json(value: &serde_json::Value) -> Result { - // Parse from nested structure: {"rateLimit": {"period": "25", ...}} - if let Some(rate_limit_obj) = value.get("rateLimit") { - serde_json::from_value(rate_limit_obj.clone()) - .map_err(|e| e.to_string()) - } else { - Err("rateLimit field not found".to_string()) - } - } - - pub fn period_secs(&self) -> u64 { - self.period.parse().unwrap_or(60) - } - - pub fn duration_secs(&self) -> u64 { - self.duration.parse().unwrap_or(60) - } - - pub fn requests_count(&self) -> usize { - self.requests.parse().unwrap_or(100) - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct RuleSet { - pub asn: Vec>>, - pub country: Vec>>, - pub ips: Vec, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ErrorResponse { - pub details: Details, - pub error: String, - pub success: bool, -} - -// Global configuration store accessible across services -// Uses Arc to allow sharing config without cloning the entire struct -static GLOBAL_CONFIG: OnceLock>>>> = OnceLock::new(); - -pub fn global_config() -> Arc>>> { - GLOBAL_CONFIG - .get_or_init(|| Arc::new(RwLock::new(None))) - .clone() -} - -/// Set the global config, wrapping it in Arc for efficient sharing -pub fn set_global_config(cfg: Config) { - let store = global_config(); - if let Ok(mut guard) = store.write() { - *guard = Some(Arc::new(cfg)); - } -} - -/// Get a clone of the Arc for efficient access without cloning the entire config -pub fn get_global_config_arc() -> Option> { - let store = global_config(); - if let Ok(guard) = store.read() { - guard.clone() - } else { - None - } -} - -pub async fn fetch_config( - base_url: String, - api_key: String, -) -> Result> { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .map_err(|e| anyhow::anyhow!("Failed to get global HTTP client: {}", e))?; - - let url = format!("{}/config", base_url); - - let response = client - .get(&url) - .header("Authorization", format!("Bearer {}", api_key)) - .header("Accept-Encoding", "gzip") - .send() - .await?; - - let status = response.status(); - match status { - StatusCode::OK => { - // Check if response is gzipped by looking at Content-Encoding header first - let content_encoding = response.headers() - .get("content-encoding") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); // Convert to owned String to avoid borrow issues - - let bytes = response.bytes().await?; - - let json_text = if content_encoding.contains("gzip") || - (bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b) { - // Response is gzipped, decompress it - let mut decoder = GzDecoder::new(&bytes[..]); - let mut decompressed_bytes = Vec::new(); - decoder.read_to_end(&mut decompressed_bytes) - .map_err(|e| 
format!("Failed to decompress gzipped response: {}", e))?; - - // Check if the decompressed content is also gzipped (double compression) - let final_bytes = if decompressed_bytes.len() >= 2 && decompressed_bytes[0] == 0x1f && decompressed_bytes[1] == 0x8b { - let mut second_decoder = GzDecoder::new(&decompressed_bytes[..]); - let mut final_bytes = Vec::new(); - second_decoder.read_to_end(&mut final_bytes) - .map_err(|e| format!("Failed to decompress second gzip layer: {}", e))?; - final_bytes - } else { - decompressed_bytes - }; - - // Try to convert to UTF-8 string - match String::from_utf8(final_bytes) { - Ok(text) => text, - Err(e) => { - return Err(format!("Final decompressed response contains invalid UTF-8: {}", e).into()); - } - } - } else { - // Response is not gzipped, use as-is - String::from_utf8(bytes.to_vec()) - .map_err(|e| format!("Response contains invalid UTF-8: {}", e))? - }; - - // Check if response body is empty - let json_text = json_text.trim(); - if json_text.is_empty() { - return Err("API returned empty response body".into()); - } - - let body: ConfigApiResponse = serde_json::from_str(json_text) - .map_err(|e| { - let preview = if json_text.len() > 200 { - format!("{}...", &json_text[..200]) - } else { - json_text.to_string() - }; - format!("Failed to parse JSON response: {}. Response preview: {}", e, preview) - })?; - - // Update global config snapshot - set_global_config(body.config.clone()); - Ok(body) - } - StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND | StatusCode::INTERNAL_SERVER_ERROR => { - let response_text = response.text().await?; - let trimmed = response_text.trim(); - let status_code = status.as_u16(); - if trimmed.is_empty() { - return Err(format!("API returned empty response body with status {}", status_code).into()); - } - match serde_json::from_str::(trimmed) { - Ok(body) => Err(format!("API Error: {}", body.error).into()), - Err(e) => { - let preview = if trimmed.len() > 200 { - format!("{}...", &trimmed[..200]) - } else { - trimmed.to_string() - }; - Err(format!("API returned status {} but response is not valid JSON: {}. Response preview: {}", - status_code, e, preview).into()) - } - } - } - - status => Err(format!( - "Unexpected API status code: {} - {}", - status.as_u16(), - status.canonical_reason().unwrap_or("Unknown") - ) - .into()), - - } -} - -/// Fetch config and run a user-provided callback to apply it. -/// The callback can update WAF rules, BPF maps, caches, etc. 
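A hypothetical call site for `fetch_and_apply`, assuming the elided error type is `Box<dyn std::error::Error>` (consistent with how `fetch_config` builds its errors); the endpoint and key source are placeholders:

```rust
// Sketch only: refresh config once and apply it via the callback.
async fn refresh_once() -> Result<(), Box<dyn std::error::Error>> {
    fetch_and_apply(
        "https://api.example.com".to_string(),        // assumed endpoint
        std::env::var("API_KEY").unwrap_or_default(), // assumed key source
        |resp| {
            log::info!("fetched {} WAF rule(s)", resp.config.waf_rules.rules.len());
            Ok(()) // apply to WAF engine / BPF maps / caches here
        },
    )
    .await
}
```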
-pub async fn fetch_and_apply( - base_url: String, - api_key: String, - mut on_config: F, -) -> Result<(), Box> -where - F: FnMut(&ConfigApiResponse) -> Result<(), Box>, -{ - let resp = fetch_config(base_url, api_key).await?; - on_config(&resp)?; - Ok(()) -} - -/// Config worker that periodically fetches and applies configuration from API -pub struct ConfigWorker { - base_url: String, - api_key: String, - refresh_interval_secs: u64, - skels: Vec>>, - security_rules_config_path: std::path::PathBuf, - is_agent_mode: bool, - nftables_firewall: Option>>, - iptables_firewall: Option>>, -} - -impl ConfigWorker { - pub fn new(base_url: String, api_key: String, refresh_interval_secs: u64, skels: Vec>>, security_rules_config_path: std::path::PathBuf) -> Self { - Self { - base_url, - api_key, - refresh_interval_secs, - skels, - security_rules_config_path, - is_agent_mode: false, - nftables_firewall: None, - iptables_firewall: None, - } - } - - pub fn with_agent_mode(mut self, is_agent_mode: bool) -> Self { - self.is_agent_mode = is_agent_mode; - self - } - - pub fn with_nftables(mut self, nft_fw: Option>>) -> Self { - self.nftables_firewall = nft_fw; - self - } - - pub fn with_iptables(mut self, ipt_fw: Option>>) -> Self { - self.iptables_firewall = ipt_fw; - self - } -} - -/// Load configuration from a local YAML file -pub async fn load_config_from_file(path: &std::path::PathBuf) -> Result> { - use anyhow::Context; - let content = tokio::fs::read_to_string(path).await - .with_context(|| format!("Failed to read security rules file: {:?}", path))?; - let config: Config = serde_yaml::from_str(&content) - .with_context(|| format!("Failed to parse security rules YAML from file: {:?}", path))?; - Ok(ConfigApiResponse { success: true, config }) -} - -impl Worker for ConfigWorker { - fn name(&self) -> &str { - "config" - } - - fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { - let base_url = self.base_url.clone(); - let api_key = self.api_key.clone(); - let refresh_interval_secs = self.refresh_interval_secs; - let skels = self.skels.clone(); - let security_rules_config_path = self.security_rules_config_path.clone(); - let worker_name = self.name().to_string(); - let is_agent_mode = self.is_agent_mode; - let nftables_firewall = self.nftables_firewall.clone(); - let iptables_firewall = self.iptables_firewall.clone(); - - // Create previous rules state for nftables/iptables - let nft_previous_rules: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let nft_previous_rules_v6: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let ipt_previous_rules: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let ipt_previous_rules_v6: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - - tokio::spawn(async move { - // Check if running in local mode (no API key) - if api_key.is_empty() { - // LOCAL MODE: Load config from file once at startup - log::info!("[{}] No API key provided, loading config from local file...", worker_name); - match load_config_from_file(&security_rules_config_path).await { - Ok(config_response) => { - log::info!("[{}] Successfully loaded config from local file (access_rules: {}, waf_rules: {})", - worker_name, - config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len(), - config_response.config.waf_rules.rules.len() - ); - - // Apply rules to BPF maps (skip WAF update in agent mode) - if !skels.is_empty() { - if 
let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from local config: {}", worker_name, e); - } - } else if let Some(ref nft_fw) = nftables_firewall { - // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); - } - } else if let Some(ref ipt_fw) = iptables_firewall { - // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); - } - } - } - Err(e) => { - log::error!("[{}] Failed to load config from local file: {}", worker_name, e); - } - } - - // Set up file watcher for automatic reloading - log::info!("[{}] Running in local config mode - watching {} for changes", - worker_name, security_rules_config_path.display()); - - // Set up file watcher - use notify::{Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; - use notify::event::ModifyKind; - use std::time::{Duration, Instant}; - use std::path::Path; - use tokio::task; - - let file_path = security_rules_config_path.clone(); - let parent_dir = Path::new(&file_path).parent().unwrap().to_path_buf(); - let (local_tx, mut local_rx) = tokio::sync::mpsc::channel::>(1); - - let _watcher_handle = task::spawn_blocking({ - let parent_dir = parent_dir.clone(); - move || { - let mut watcher = RecommendedWatcher::new( - move |res| { - let _ = local_tx.blocking_send(res); - }, - NotifyConfig::default(), - ).expect("Failed to create file watcher"); - watcher.watch(&parent_dir, RecursiveMode::NonRecursive) - .expect("Failed to watch security rules directory"); - let (_rtx, mut rrx) = tokio::sync::mpsc::channel::(1); - let _ = rrx.blocking_recv(); - } - }); - - let mut last_reload = Instant::now(); - - loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { - break; - } - } - event = local_rx.recv() => { - if let Some(event) = event { - match event { - Ok(e) => { - match e.kind { - EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(..) | EventKind::Remove(..) 
=> { - // Check if the modified file is the security_rules file - let is_target_file = e.paths.iter().any(|p| { - p.to_str().map_or(false, |s| { - s.ends_with("yaml") && - p.file_name() == file_path.file_name() - }) - }); - - if is_target_file && last_reload.elapsed() > Duration::from_secs(2) { - last_reload = Instant::now(); - log::info!("[{}] Security rules file changed, reloading...", worker_name); - - match load_config_from_file(&file_path).await { - Ok(config_response) => { - log::info!("[{}] Config reloaded successfully (access_rules: {}, waf_rules: {})", - worker_name, - config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len(), - config_response.config.waf_rules.rules.len() - ); - - // Apply rules to BPF maps (skip WAF update in agent mode) - if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from reloaded config: {}", worker_name, e); - } - } else if let Some(ref nft_fw) = nftables_firewall { - // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); - } - } else if let Some(ref ipt_fw) = iptables_firewall { - // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); - } - } - } - Err(e) => { - log::error!("[{}] Failed to reload config from local file: {}", worker_name, e); - } - } - } - } - _ => {} - } - } - Err(e) => { - log::error!("[{}] File watch error: {:?}", worker_name, e); - } - } - } - } - } - } - } else { - // API MODE: Fetch from API periodically - match fetch_config(base_url.clone(), api_key.clone()).await { - Ok(_config_response) => { - // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) - if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from initial config: {}", worker_name, e); - } - } else if let Some(ref nft_fw) = nftables_firewall { - // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); - } - } else if let Some(ref ipt_fw) = iptables_firewall { - // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); - } - } - } - Err(e) => { - log::warn!("[{}] Failed to fetch initial config from API: {}", worker_name, e); - log::warn!("[{}] Will retry on next scheduled interval", worker_name); - } - } - - // Set up periodic refresh interval - let mut interval = interval(Duration::from_secs(refresh_interval_secs)); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! 
{ - _ = shutdown.changed() => { - if *shutdown.borrow() { - break; - } - } - _ = interval.tick() => { - log::debug!("[{}] Periodic config refresh triggered", worker_name); - match fetch_config(base_url.clone(), api_key.clone()).await { - Ok(config_response) => { - log::debug!("[{}] Config refreshed successfully (waf_rules: {}, access_rules: {})", - worker_name, - config_response.config.waf_rules.rules.len(), - config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len() - ); - - // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) - if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from config: {}", worker_name, e); - } - } else if let Some(ref nft_fw) = nftables_firewall { - // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); - } - } else if let Some(ref ipt_fw) = iptables_firewall { - // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); - } - } - } - Err(e) => { - log::warn!("[{}] Failed to fetch config from API: {}", worker_name, e); - } - } - } - } - } - } - - log::info!("[{}] Config fetcher task stopped", worker_name); - }) - } -} +use hyper::StatusCode; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::io::Read; +use flate2::read::GzDecoder; +use std::sync::{Arc, OnceLock, RwLock}; +use tokio::sync::watch; +use tokio::time::{interval, Duration, MissedTickBehavior}; +use crate::content_scanning::ContentScanningConfig; +use crate::http_client::get_global_reqwest_client; +use crate::worker::Worker; + +pub type Details = serde_json::Value; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ConfigApiResponse { + pub success: bool, + pub config: Config, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Config { + pub access_rules: AccessRule, + pub waf_rules: WafRules, + #[serde(default)] + pub content_scanning: ContentScanningConfig, + pub created_at: String, + pub updated_at: String, + pub last_modified: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AccessRule { + pub id: String, + pub name: String, + pub description: String, + pub allow: RuleSet, + pub block: RuleSet, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WafRules { + pub rules: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WafRule { + pub id: String, + pub name: String, + pub org_id: String, + pub description: String, + pub action: String, + pub expression: String, + #[serde(default)] + pub config: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RateLimitConfig { + pub period: String, + pub duration: String, + pub requests: String, +} + +impl RateLimitConfig { + pub fn from_json(value: &serde_json::Value) -> Result { + // Parse from nested structure: {"rateLimit": {"period": "25", ...}} + if let Some(rate_limit_obj) = value.get("rateLimit") { + serde_json::from_value(rate_limit_obj.clone()) + .map_err(|e| e.to_string()) + } else { + Err("rateLimit field not found".to_string()) + } + } + + pub fn 
period_secs(&self) -> u64 { + self.period.parse().unwrap_or(60) + } + + pub fn duration_secs(&self) -> u64 { + self.duration.parse().unwrap_or(60) + } + + pub fn requests_count(&self) -> usize { + self.requests.parse().unwrap_or(100) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RuleSet { + pub asn: Vec>>, + pub country: Vec>>, + pub ips: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorResponse { + pub details: Details, + pub error: String, + pub success: bool, +} + +// Global configuration store accessible across services +// Uses Arc to allow sharing config without cloning the entire struct +static GLOBAL_CONFIG: OnceLock>>>> = OnceLock::new(); + +pub fn global_config() -> Arc>>> { + GLOBAL_CONFIG + .get_or_init(|| Arc::new(RwLock::new(None))) + .clone() +} + +/// Set the global config, wrapping it in Arc for efficient sharing +pub fn set_global_config(cfg: Config) { + let store = global_config(); + if let Ok(mut guard) = store.write() { + *guard = Some(Arc::new(cfg)); + } +} + +/// Get a clone of the Arc for efficient access without cloning the entire config +pub fn get_global_config_arc() -> Option> { + let store = global_config(); + if let Ok(guard) = store.read() { + guard.clone() + } else { + None + } +} + +pub async fn fetch_config( + base_url: String, + api_key: String, +) -> Result> { + // Use shared HTTP client with keepalive instead of creating new client + let client = get_global_reqwest_client() + .map_err(|e| anyhow::anyhow!("Failed to get global HTTP client: {}", e))?; + + let url = format!("{}/config", base_url); + + let response = client + .get(&url) + .header("Authorization", format!("Bearer {}", api_key)) + .header("Accept-Encoding", "gzip") + .send() + .await?; + + let status = response.status(); + match status { + StatusCode::OK => { + // Check if response is gzipped by looking at Content-Encoding header first + let content_encoding = response.headers() + .get("content-encoding") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); // Convert to owned String to avoid borrow issues + + let bytes = response.bytes().await?; + + let json_text = if content_encoding.contains("gzip") || + (bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b) { + // Response is gzipped, decompress it + let mut decoder = GzDecoder::new(&bytes[..]); + let mut decompressed_bytes = Vec::new(); + decoder.read_to_end(&mut decompressed_bytes) + .map_err(|e| format!("Failed to decompress gzipped response: {}", e))?; + + // Check if the decompressed content is also gzipped (double compression) + let final_bytes = if decompressed_bytes.len() >= 2 && decompressed_bytes[0] == 0x1f && decompressed_bytes[1] == 0x8b { + let mut second_decoder = GzDecoder::new(&decompressed_bytes[..]); + let mut final_bytes = Vec::new(); + second_decoder.read_to_end(&mut final_bytes) + .map_err(|e| format!("Failed to decompress second gzip layer: {}", e))?; + final_bytes + } else { + decompressed_bytes + }; + + // Try to convert to UTF-8 string + match String::from_utf8(final_bytes) { + Ok(text) => text, + Err(e) => { + return Err(format!("Final decompressed response contains invalid UTF-8: {}", e).into()); + } + } + } else { + // Response is not gzipped, use as-is + String::from_utf8(bytes.to_vec()) + .map_err(|e| format!("Response contains invalid UTF-8: {}", e))? 
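The gzip handling above (and its mirror in the removed half of this hunk) peels at most two layers by checking the 0x1f 0x8b magic twice; generalized to a loop it reads as follows. A sketch using the same `flate2` decoder, not part of the patch:

```rust
use flate2::read::GzDecoder;
use std::io::Read;

// Hypothetical loop form of the one-or-two-layer gunzip logic above.
fn gunzip_layers(mut bytes: Vec<u8>) -> std::io::Result<Vec<u8>> {
    while bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b {
        let mut out = Vec::new();
        GzDecoder::new(&bytes[..]).read_to_end(&mut out)?;
        bytes = out;
    }
    Ok(bytes)
}
```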
+ }; + + // Check if response body is empty + let json_text = json_text.trim(); + if json_text.is_empty() { + return Err("API returned empty response body".into()); + } + + let body: ConfigApiResponse = serde_json::from_str(json_text) + .map_err(|e| { + let preview = if json_text.len() > 200 { + format!("{}...", &json_text[..200]) + } else { + json_text.to_string() + }; + format!("Failed to parse JSON response: {}. Response preview: {}", e, preview) + })?; + + // Update global config snapshot + set_global_config(body.config.clone()); + Ok(body) + } + StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND | StatusCode::INTERNAL_SERVER_ERROR => { + let response_text = response.text().await?; + let trimmed = response_text.trim(); + let status_code = status.as_u16(); + if trimmed.is_empty() { + return Err(format!("API returned empty response body with status {}", status_code).into()); + } + match serde_json::from_str::(trimmed) { + Ok(body) => Err(format!("API Error: {}", body.error).into()), + Err(e) => { + let preview = if trimmed.len() > 200 { + format!("{}...", &trimmed[..200]) + } else { + trimmed.to_string() + }; + Err(format!("API returned status {} but response is not valid JSON: {}. Response preview: {}", + status_code, e, preview).into()) + } + } + } + + status => Err(format!( + "Unexpected API status code: {} - {}", + status.as_u16(), + status.canonical_reason().unwrap_or("Unknown") + ) + .into()), + + } +} + +/// Fetch config and run a user-provided callback to apply it. +/// The callback can update WAF rules, BPF maps, caches, etc. +pub async fn fetch_and_apply( + base_url: String, + api_key: String, + mut on_config: F, +) -> Result<(), Box> +where + F: FnMut(&ConfigApiResponse) -> Result<(), Box>, +{ + let resp = fetch_config(base_url, api_key).await?; + on_config(&resp)?; + Ok(()) +} + +/// Config worker that periodically fetches and applies configuration from API +pub struct ConfigWorker { + base_url: String, + api_key: String, + refresh_interval_secs: u64, + skels: Vec>>, + security_rules_config_path: std::path::PathBuf, + is_agent_mode: bool, + nftables_firewall: Option>>, + iptables_firewall: Option>>, +} + +impl ConfigWorker { + pub fn new(base_url: String, api_key: String, refresh_interval_secs: u64, skels: Vec>>, security_rules_config_path: std::path::PathBuf) -> Self { + Self { + base_url, + api_key, + refresh_interval_secs, + skels, + security_rules_config_path, + is_agent_mode: false, + nftables_firewall: None, + iptables_firewall: None, + } + } + + pub fn with_agent_mode(mut self, is_agent_mode: bool) -> Self { + self.is_agent_mode = is_agent_mode; + self + } + + pub fn with_nftables(mut self, nft_fw: Option>>) -> Self { + self.nftables_firewall = nft_fw; + self + } + + pub fn with_iptables(mut self, ipt_fw: Option>>) -> Self { + self.iptables_firewall = ipt_fw; + self + } +} + +/// Load configuration from a local YAML file +pub async fn load_config_from_file(path: &std::path::PathBuf) -> Result> { + use anyhow::Context; + let content = tokio::fs::read_to_string(path).await + .with_context(|| format!("Failed to read security rules file: {:?}", path))?; + let config: Config = serde_yaml::from_str(&content) + .with_context(|| format!("Failed to parse security rules YAML from file: {:?}", path))?; + Ok(ConfigApiResponse { success: true, config }) +} + +impl Worker for ConfigWorker { + fn name(&self) -> &str { + "config" + } + + fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { + let base_url = self.base_url.clone(); + let api_key = 
self.api_key.clone(); + let refresh_interval_secs = self.refresh_interval_secs; + let skels = self.skels.clone(); + let security_rules_config_path = self.security_rules_config_path.clone(); + let worker_name = self.name().to_string(); + let is_agent_mode = self.is_agent_mode; + let nftables_firewall = self.nftables_firewall.clone(); + let iptables_firewall = self.iptables_firewall.clone(); + + // Create previous rules state for nftables/iptables + let nft_previous_rules: Arc>> = + Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let nft_previous_rules_v6: Arc>> = + Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let ipt_previous_rules: Arc>> = + Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let ipt_previous_rules_v6: Arc>> = + Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + + tokio::spawn(async move { + // Check if running in local mode (no API key) + if api_key.is_empty() { + // LOCAL MODE: Load config from file once at startup + log::info!("[{}] No API key provided, loading config from local file...", worker_name); + match load_config_from_file(&security_rules_config_path).await { + Ok(config_response) => { + log::info!("[{}] Successfully loaded config from local file (access_rules: {}, waf_rules: {})", + worker_name, + config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len(), + config_response.config.waf_rules.rules.len() + ); + + // Apply rules to BPF maps (skip WAF update in agent mode) + if !skels.is_empty() { + if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + log::error!("[{}] Failed to apply rules from local config: {}", worker_name, e); + } + } else if let Some(ref nft_fw) = nftables_firewall { + // Apply rules to nftables if BPF is not available + if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + } + } else if let Some(ref ipt_fw) = iptables_firewall { + // Apply rules to iptables if BPF and nftables are not available + if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + } + } + } + Err(e) => { + log::error!("[{}] Failed to load config from local file: {}", worker_name, e); + } + } + + // Set up file watcher for automatic reloading + log::info!("[{}] Running in local config mode - watching {} for changes", + worker_name, security_rules_config_path.display()); + + // Set up file watcher + use notify::{Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; + use notify::event::ModifyKind; + use std::time::{Duration, Instant}; + use std::path::Path; + use tokio::task; + + let file_path = security_rules_config_path.clone(); + let parent_dir = Path::new(&file_path).parent().unwrap().to_path_buf(); + let (local_tx, mut local_rx) = tokio::sync::mpsc::channel::>(1); + + let _watcher_handle = task::spawn_blocking({ + let parent_dir = parent_dir.clone(); + move || { + let mut watcher = RecommendedWatcher::new( + move |res| { + let _ = local_tx.blocking_send(res); + }, + NotifyConfig::default(), + ).expect("Failed to create file watcher"); + watcher.watch(&parent_dir, RecursiveMode::NonRecursive) + .expect("Failed to watch security rules directory"); + let (_rtx, mut rrx) = 
tokio::sync::mpsc::channel::(1); + let _ = rrx.blocking_recv(); + } + }); + + let mut last_reload = Instant::now(); + + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + break; + } + } + event = local_rx.recv() => { + if let Some(event) = event { + match event { + Ok(e) => { + match e.kind { + EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(..) | EventKind::Remove(..) => { + // Check if the modified file is the security_rules file + let is_target_file = e.paths.iter().any(|p| { + p.to_str().map_or(false, |s| { + s.ends_with("yaml") && + p.file_name() == file_path.file_name() + }) + }); + + if is_target_file && last_reload.elapsed() > Duration::from_secs(2) { + last_reload = Instant::now(); + log::info!("[{}] Security rules file changed, reloading...", worker_name); + + match load_config_from_file(&file_path).await { + Ok(config_response) => { + log::info!("[{}] Config reloaded successfully (access_rules: {}, waf_rules: {})", + worker_name, + config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len(), + config_response.config.waf_rules.rules.len() + ); + + // Apply rules to BPF maps (skip WAF update in agent mode) + if !skels.is_empty() { + if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + log::error!("[{}] Failed to apply rules from reloaded config: {}", worker_name, e); + } + } else if let Some(ref nft_fw) = nftables_firewall { + // Apply rules to nftables if BPF is not available + if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + } + } else if let Some(ref ipt_fw) = iptables_firewall { + // Apply rules to iptables if BPF and nftables are not available + if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + } + } + } + Err(e) => { + log::error!("[{}] Failed to reload config from local file: {}", worker_name, e); + } + } + } + } + _ => {} + } + } + Err(e) => { + log::error!("[{}] File watch error: {:?}", worker_name, e); + } + } + } + } + } + } + } else { + // API MODE: Fetch from API periodically + match fetch_config(base_url.clone(), api_key.clone()).await { + Ok(_config_response) => { + // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) + if !skels.is_empty() { + if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + log::error!("[{}] Failed to apply rules from initial config: {}", worker_name, e); + } + } else if let Some(ref nft_fw) = nftables_firewall { + // Apply rules to nftables if BPF is not available + if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + } + } else if let Some(ref ipt_fw) = iptables_firewall { + // Apply rules to iptables if BPF and nftables are not available + if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + } + } + } + Err(e) => { + log::warn!("[{}] Failed to fetch initial config from API: {}", worker_name, e); + log::warn!("[{}] Will retry on next scheduled interval", worker_name); + } + } + + // 
Set up periodic refresh interval + let mut interval = interval(Duration::from_secs(refresh_interval_secs)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + break; + } + } + _ = interval.tick() => { + log::debug!("[{}] Periodic config refresh triggered", worker_name); + match fetch_config(base_url.clone(), api_key.clone()).await { + Ok(config_response) => { + log::debug!("[{}] Config refreshed successfully (waf_rules: {}, access_rules: {})", + worker_name, + config_response.config.waf_rules.rules.len(), + config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len() + ); + + // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) + if !skels.is_empty() { + if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + log::error!("[{}] Failed to apply rules from config: {}", worker_name, e); + } + } else if let Some(ref nft_fw) = nftables_firewall { + // Apply rules to nftables if BPF is not available + if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + } + } else if let Some(ref ipt_fw) = iptables_firewall { + // Apply rules to iptables if BPF and nftables are not available + if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + } + } + } + Err(e) => { + log::warn!("[{}] Failed to fetch config from API: {}", worker_name, e); + } + } + } + } + } + } + + log::info!("[{}] Config fetcher task stopped", worker_name); + }) + } +} diff --git a/src/worker/geoip_mmdb.rs b/src/worker/geoip_mmdb.rs index d85f7bf..7060e6b 100644 --- a/src/worker/geoip_mmdb.rs +++ b/src/worker/geoip_mmdb.rs @@ -1,329 +1,329 @@ -use std::collections::HashMap; -use std::path::PathBuf; - -use anyhow::{anyhow, Context, Result}; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tokio::time::{interval, Duration, MissedTickBehavior}; - -use crate::http_client::get_global_reqwest_client; -use crate::worker::Worker; - -#[derive(Clone)] -pub enum GeoipDatabaseType { - Country, - Asn, - City, -} - -pub struct GeoipMmdbWorker { - interval_secs: u64, - mmdb_base_url: String, - versions_url: String, - mmdb_path: Option, - headers: Option>, - db_type: GeoipDatabaseType, -} - -impl GeoipMmdbWorker { - pub fn new( - interval_secs: u64, - mmdb_base_url: String, - versions_url: String, - mmdb_path: Option, - headers: Option>, - db_type: GeoipDatabaseType, - ) -> Self { - Self { - interval_secs, - mmdb_base_url, - versions_url, - mmdb_path, - headers, - db_type, - } - } -} - -impl Worker for GeoipMmdbWorker { - fn name(&self) -> &str { - match self.db_type { - GeoipDatabaseType::Country => "geoip_country_mmdb", - GeoipDatabaseType::Asn => "geoip_asn_mmdb", - GeoipDatabaseType::City => "geoip_city_mmdb", - } - } - - fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { - let interval_secs = self.interval_secs; - let mmdb_base_url = self.mmdb_base_url.clone(); - let versions_url = self.versions_url.clone(); - let mmdb_path = self.mmdb_path.clone(); - let headers = self.headers.clone(); - let db_type_enum = self.db_type.clone(); - let db_type = match self.db_type { - GeoipDatabaseType::Country => "country", - GeoipDatabaseType::Asn => "asn", 
- GeoipDatabaseType::City => "city", - }; - - tokio::spawn(async move { - let worker_name = format!("geoip_{}_mmdb", db_type); - let mut current_version: Option = None; - - log::info!( - "[{}] Starting GeoIP MMDB worker (interval: {}s)", - worker_name, - interval_secs - ); - - match sync_mmdb( - &mmdb_base_url, - &versions_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - db_type_enum.clone(), - ) - .await - { - Ok(new_ver) => { - // Only refresh if the file was actually updated (not "existing") - if new_ver.as_ref().map(|v| v != "existing").unwrap_or(true) { - current_version = new_ver; - } - // Always refresh the database reader to ensure it's loaded - match db_type { - "country" => { - if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); - } - } - "asn" => { - if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); - } - } - "city" => { - if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); - } - } - _ => {} - } - } - Err(e) => { - log::warn!("[{}] Initial GeoIP {} MMDB sync failed: {}", worker_name, db_type, e); - } - } - - let mut ticker = interval(Duration::from_secs(interval_secs)); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = ticker.tick() => { - match sync_mmdb( - &mmdb_base_url, - &versions_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - db_type_enum.clone(), - ).await { - Ok(new_ver) => { - if new_ver != current_version { - current_version = new_ver; - log::info!("[{}] GeoIP {} MMDB updated to latest version", worker_name, db_type); - // Refresh only the specific database that was updated - match db_type { - "country" => { - if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); - } - } - "asn" => { - if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); - } - } - "city" => { - if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); - } - } - _ => {} - } - } else { - log::debug!("[{}] GeoIP {} MMDB already at latest version", worker_name, db_type); - } - } - Err(e) => { - log::warn!("[{}] Periodic GeoIP MMDB sync failed: {}", worker_name, e); - } - } - } - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] GeoIP MMDB worker received shutdown signal", worker_name); - break; - } - } - } - } - }) - } -} - -async fn sync_mmdb( - mmdb_base_url: &str, - versions_url: &str, - mmdb_path: Option, - current_version: Option, - headers: Option>, - db_type: GeoipDatabaseType, -) -> Result> { - let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for GeoIP MMDB worker"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if local_path.extension().is_none() { - let filename = match db_type { - GeoipDatabaseType::Country => "GeoLite2-Country.mmdb", - GeoipDatabaseType::Asn => "GeoLite2-ASN.mmdb", - GeoipDatabaseType::City => "GeoLite2-City.mmdb", - }; - local_path = local_path.join(filename); - } - - if versions_url.is_empty() { - // Check if file already exists - if it does, 
skip download to avoid redundant writes - // This prevents multiple workers from downloading the same file simultaneously - if local_path.exists() { - log::debug!("GeoIP MMDB file already exists at {:?}, skipping download", local_path); - return Ok(Some("existing".to_string())); - } - - log::debug!("No versions URL provided, downloading directly from {}", mmdb_base_url); - - let bytes = if mmdb_base_url.starts_with("http://") || mmdb_base_url.starts_with("https://") { - download_mmdb(mmdb_base_url, headers.as_ref()).await? - } else { - let src_path = PathBuf::from(mmdb_base_url); - tokio::fs::read(&src_path) - .await - .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? - }; - - if let Some(parent) = local_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; - } - - tokio::fs::write(&local_path, &bytes) - .await - .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; - - log::info!("GeoIP MMDB written to {:?}", local_path); - return Ok(Some("direct".to_string())); - } - - let versions_text = if versions_url.starts_with("http://") || versions_url.starts_with("https://") { - let client = get_global_reqwest_client() - .context("Failed to get HTTP client for versions download")?; - let mut req = client.get(versions_url); - if let Some(ref hdrs) = headers { - for (key, value) in hdrs { - req = req.header(key, value); - } - } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download versions file from {}", versions_url))?; - let status = resp.status(); - if !status.is_success() { - return Err(anyhow!( - "Versions download failed: status {} from {}", - status, - versions_url - )); - } - resp.text().await.context("Failed to read versions body")? - } else { - tokio::fs::read_to_string(versions_url) - .await - .with_context(|| format!("Failed to read versions file from {}", versions_url))? - }; - - let latest = parse_latest_version(&versions_text) - .ok_or_else(|| anyhow!("Failed to parse latest version from versions file"))?; - - if let Some(ref curr) = current_version { - if curr == &latest { - log::debug!("GeoIP MMDB already at latest version: {}", latest); - return Ok(current_version); - } - } - - let base = mmdb_base_url.trim_end_matches('/'); - let bytes = if base.starts_with("http://") || base.starts_with("https://") { - let url = format!("{}/{}", base, latest); - download_mmdb(&url, headers.as_ref()).await? - } else { - let src_path = PathBuf::from(base).join(&latest); - tokio::fs::read(&src_path) - .await - .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? 
- }; - - if let Some(parent) = local_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; - } - - tokio::fs::write(&local_path, &bytes) - .await - .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; - - log::info!("GeoIP MMDB written to {:?}", local_path); - crate::threat::refresh_geoip_mmdb().await?; - Ok(Some(latest)) -} - -fn parse_latest_version(text: &str) -> Option { - for line in text.lines() { - if let Some(rest) = line.strip_prefix("latest=") { - return Some(rest.trim().to_string()); - } - } - None -} - -async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> Result> { - let client = get_global_reqwest_client() - .context("Failed to get HTTP client for MMDB download")?; - let mut req = client.get(url); - if let Some(hdrs) = headers { - for (key, value) in hdrs { - req = req.header(key, value); - } - } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download MMDB from {}", url))?; - let status = resp.status(); - if !status.is_success() { - return Err(anyhow!( - "MMDB download failed: status {} from {}", - status, - url - )); - } - let bytes = resp.bytes().await.context("Failed to read MMDB body")?; - Ok(bytes.to_vec()) -} - - +use std::collections::HashMap; +use std::path::PathBuf; + +use anyhow::{anyhow, Context, Result}; +use tokio::sync::watch; +use tokio::task::JoinHandle; +use tokio::time::{interval, Duration, MissedTickBehavior}; + +use crate::http_client::get_global_reqwest_client; +use crate::worker::Worker; + +#[derive(Clone)] +pub enum GeoipDatabaseType { + Country, + Asn, + City, +} + +pub struct GeoipMmdbWorker { + interval_secs: u64, + mmdb_base_url: String, + versions_url: String, + mmdb_path: Option, + headers: Option>, + db_type: GeoipDatabaseType, +} + +impl GeoipMmdbWorker { + pub fn new( + interval_secs: u64, + mmdb_base_url: String, + versions_url: String, + mmdb_path: Option, + headers: Option>, + db_type: GeoipDatabaseType, + ) -> Self { + Self { + interval_secs, + mmdb_base_url, + versions_url, + mmdb_path, + headers, + db_type, + } + } +} + +impl Worker for GeoipMmdbWorker { + fn name(&self) -> &str { + match self.db_type { + GeoipDatabaseType::Country => "geoip_country_mmdb", + GeoipDatabaseType::Asn => "geoip_asn_mmdb", + GeoipDatabaseType::City => "geoip_city_mmdb", + } + } + + fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { + let interval_secs = self.interval_secs; + let mmdb_base_url = self.mmdb_base_url.clone(); + let versions_url = self.versions_url.clone(); + let mmdb_path = self.mmdb_path.clone(); + let headers = self.headers.clone(); + let db_type_enum = self.db_type.clone(); + let db_type = match self.db_type { + GeoipDatabaseType::Country => "country", + GeoipDatabaseType::Asn => "asn", + GeoipDatabaseType::City => "city", + }; + + tokio::spawn(async move { + let worker_name = format!("geoip_{}_mmdb", db_type); + let mut current_version: Option = None; + + log::info!( + "[{}] Starting GeoIP MMDB worker (interval: {}s)", + worker_name, + interval_secs + ); + + match sync_mmdb( + &mmdb_base_url, + &versions_url, + mmdb_path.clone(), + current_version.clone(), + headers.clone(), + db_type_enum.clone(), + ) + .await + { + Ok(new_ver) => { + // Only refresh if the file was actually updated (not "existing") + if new_ver.as_ref().map(|v| v != "existing").unwrap_or(true) { + current_version = new_ver; + } + // Always refresh the database reader to ensure it's loaded + match db_type { + "country" => { 
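+                            // Non-fatal by design: if the refresh fails, the previously
+                            // loaded reader stays active and the next tick retries.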
+ if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); + } + } + "asn" => { + if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); + } + } + "city" => { + if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); + } + } + _ => {} + } + } + Err(e) => { + log::warn!("[{}] Initial GeoIP {} MMDB sync failed: {}", worker_name, db_type, e); + } + } + + let mut ticker = interval(Duration::from_secs(interval_secs)); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = ticker.tick() => { + match sync_mmdb( + &mmdb_base_url, + &versions_url, + mmdb_path.clone(), + current_version.clone(), + headers.clone(), + db_type_enum.clone(), + ).await { + Ok(new_ver) => { + if new_ver != current_version { + current_version = new_ver; + log::info!("[{}] GeoIP {} MMDB updated to latest version", worker_name, db_type); + // Refresh only the specific database that was updated + match db_type { + "country" => { + if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); + } + } + "asn" => { + if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); + } + } + "city" => { + if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { + log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); + } + } + _ => {} + } + } else { + log::debug!("[{}] GeoIP {} MMDB already at latest version", worker_name, db_type); + } + } + Err(e) => { + log::warn!("[{}] Periodic GeoIP MMDB sync failed: {}", worker_name, e); + } + } + } + _ = shutdown.changed() => { + if *shutdown.borrow() { + log::info!("[{}] GeoIP MMDB worker received shutdown signal", worker_name); + break; + } + } + } + } + }) + } +} + +async fn sync_mmdb( + mmdb_base_url: &str, + versions_url: &str, + mmdb_path: Option, + current_version: Option, + headers: Option>, + db_type: GeoipDatabaseType, +) -> Result> { + let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for GeoIP MMDB worker"))?; + + // If the path doesn't have a file extension, treat it as a directory and append the filename + if local_path.extension().is_none() { + let filename = match db_type { + GeoipDatabaseType::Country => "GeoLite2-Country.mmdb", + GeoipDatabaseType::Asn => "GeoLite2-ASN.mmdb", + GeoipDatabaseType::City => "GeoLite2-City.mmdb", + }; + local_path = local_path.join(filename); + } + + if versions_url.is_empty() { + // Check if file already exists - if it does, skip download to avoid redundant writes + // This prevents multiple workers from downloading the same file simultaneously + if local_path.exists() { + log::debug!("GeoIP MMDB file already exists at {:?}, skipping download", local_path); + return Ok(Some("existing".to_string())); + } + + log::debug!("No versions URL provided, downloading directly from {}", mmdb_base_url); + + let bytes = if mmdb_base_url.starts_with("http://") || mmdb_base_url.starts_with("https://") { + download_mmdb(mmdb_base_url, headers.as_ref()).await? + } else { + let src_path = PathBuf::from(mmdb_base_url); + tokio::fs::read(&src_path) + .await + .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? 
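+            // `bytes` now holds the raw MMDB image, whether downloaded over HTTP
+            // or read from a local source path; it is persisted to `local_path` below.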
+ }; + + if let Some(parent) = local_path.parent() { + tokio::fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; + } + + tokio::fs::write(&local_path, &bytes) + .await + .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; + + log::info!("GeoIP MMDB written to {:?}", local_path); + return Ok(Some("direct".to_string())); + } + + let versions_text = if versions_url.starts_with("http://") || versions_url.starts_with("https://") { + let client = get_global_reqwest_client() + .context("Failed to get HTTP client for versions download")?; + let mut req = client.get(versions_url); + if let Some(ref hdrs) = headers { + for (key, value) in hdrs { + req = req.header(key, value); + } + } + let resp = req + .send() + .await + .with_context(|| format!("Failed to download versions file from {}", versions_url))?; + let status = resp.status(); + if !status.is_success() { + return Err(anyhow!( + "Versions download failed: status {} from {}", + status, + versions_url + )); + } + resp.text().await.context("Failed to read versions body")? + } else { + tokio::fs::read_to_string(versions_url) + .await + .with_context(|| format!("Failed to read versions file from {}", versions_url))? + }; + + let latest = parse_latest_version(&versions_text) + .ok_or_else(|| anyhow!("Failed to parse latest version from versions file"))?; + + if let Some(ref curr) = current_version { + if curr == &latest { + log::debug!("GeoIP MMDB already at latest version: {}", latest); + return Ok(current_version); + } + } + + let base = mmdb_base_url.trim_end_matches('/'); + let bytes = if base.starts_with("http://") || base.starts_with("https://") { + let url = format!("{}/{}", base, latest); + download_mmdb(&url, headers.as_ref()).await? + } else { + let src_path = PathBuf::from(base).join(&latest); + tokio::fs::read(&src_path) + .await + .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? 
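+        // Versioned artifact: the file name is the version string itself, resolved
+        // against the HTTP base URL or, for local sources, the base directory.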
+ }; + + if let Some(parent) = local_path.parent() { + tokio::fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; + } + + tokio::fs::write(&local_path, &bytes) + .await + .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; + + log::info!("GeoIP MMDB written to {:?}", local_path); + crate::threat::refresh_geoip_mmdb().await?; + Ok(Some(latest)) +} + +fn parse_latest_version(text: &str) -> Option { + for line in text.lines() { + if let Some(rest) = line.strip_prefix("latest=") { + return Some(rest.trim().to_string()); + } + } + None +} + +async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> Result> { + let client = get_global_reqwest_client() + .context("Failed to get HTTP client for MMDB download")?; + let mut req = client.get(url); + if let Some(hdrs) = headers { + for (key, value) in hdrs { + req = req.header(key, value); + } + } + let resp = req + .send() + .await + .with_context(|| format!("Failed to download MMDB from {}", url))?; + let status = resp.status(); + if !status.is_success() { + return Err(anyhow!( + "MMDB download failed: status {} from {}", + status, + url + )); + } + let bytes = resp.bytes().await.context("Failed to read MMDB body")?; + Ok(bytes.to_vec()) +} + + diff --git a/src/worker/log.rs b/src/worker/log.rs index 745303a..652dd33 100644 --- a/src/worker/log.rs +++ b/src/worker/log.rs @@ -1,430 +1,430 @@ -use std::sync::Arc; -use std::sync::RwLock; -use std::time::{Duration, Instant}; -use tokio::sync::{mpsc, watch}; -use tokio::time::interval; -use serde::{Deserialize, Serialize}; -use chrono::{DateTime, Utc}; - -use crate::http_client; - -/// Maximum batch size allowed by the API server -const API_MAX_BATCH_SIZE: usize = 1000; - -/// Maximum number of failed events to store before dropping oldest -/// This prevents unbounded memory growth when the API is unreachable -const MAX_FAILED_EVENTS: usize = 5000; - -/// Configuration for sending access logs to arxignis server -#[derive(Debug, Clone)] -pub struct LogSenderConfig { - pub enabled: bool, - pub base_url: String, - pub api_key: String, - pub batch_size_limit: usize, // Maximum number of logs in a batch - pub batch_size_bytes: usize, // Maximum size of batch in bytes (5MB) - pub batch_timeout_secs: u64, // Maximum time to wait before sending batch (10 seconds) - pub include_request_body: bool, // Whether to include request body in logs - pub max_body_size: usize, // Maximum size for request body (1MB) -} - -impl LogSenderConfig { - pub fn new(enabled: bool, base_url: String, api_key: String) -> Self { - Self { - enabled, - base_url, - api_key, - batch_size_limit: 1000, // Default: 1000 logs per batch (API limit) - batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB - batch_timeout_secs: 10, // Default: 10 seconds - include_request_body: false, // Default: disabled - max_body_size: 1024 * 1024, // Default: 1MB - } - } - - /// Check if log sending is enabled and api_key is configured - pub fn should_send_logs(&self) -> bool { - self.enabled && !self.api_key.is_empty() - } -} - -/// Global log sender configuration -static LOG_SENDER_CONFIG: std::sync::OnceLock>>> = std::sync::OnceLock::new(); - -pub fn get_log_sender_config() -> Arc>> { - LOG_SENDER_CONFIG - .get_or_init(|| Arc::new(RwLock::new(None))) - .clone() -} - -pub fn set_log_sender_config(config: LogSenderConfig) { - let store = get_log_sender_config(); - if let Ok(mut guard) = store.write() { - *guard = Some(config); - } -} - -/// Unified event types that can be sent 
to the /events endpoint
+use std::sync::Arc;
+use std::sync::RwLock;
+use std::time::{Duration, Instant};
+use tokio::sync::{mpsc, watch};
+use tokio::time::interval;
+use serde::{Deserialize, Serialize};
+use chrono::{DateTime, Utc};
+
+use crate::http_client;
+
+/// Maximum batch size allowed by the API server
+const API_MAX_BATCH_SIZE: usize = 1000;
+
+/// Maximum number of failed events to store before dropping oldest
+/// This prevents unbounded memory growth when the API is unreachable
+const MAX_FAILED_EVENTS: usize = 5000;
+
+/// Configuration for sending access logs to arxignis server
+#[derive(Debug, Clone)]
+pub struct LogSenderConfig {
+    pub enabled: bool,
+    pub base_url: String,
+    pub api_key: String,
+    pub batch_size_limit: usize, // Maximum number of logs in a batch
+    pub batch_size_bytes: usize, // Maximum size of batch in bytes (5MB)
+    pub batch_timeout_secs: u64, // Maximum time to wait before sending batch (10 seconds)
+    pub include_request_body: bool, // Whether to include request body in logs
+    pub max_body_size: usize, // Maximum size for request body (1MB)
+}
+
+impl LogSenderConfig {
+    pub fn new(enabled: bool, base_url: String, api_key: String) -> Self {
+        Self {
+            enabled,
+            base_url,
+            api_key,
+            batch_size_limit: 1000, // Default: 1000 logs per batch (API limit)
+            batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB
+            batch_timeout_secs: 10, // Default: 10 seconds
+            include_request_body: false, // Default: disabled
+            max_body_size: 1024 * 1024, // Default: 1MB
+        }
+    }
+
+    /// Check if log sending is enabled and api_key is configured
+    pub fn should_send_logs(&self) -> bool {
+        self.enabled && !self.api_key.is_empty()
+    }
+}
+
+/// Global log sender configuration
+static LOG_SENDER_CONFIG: std::sync::OnceLock<Arc<RwLock<Option<LogSenderConfig>>>> = std::sync::OnceLock::new();
+
+pub fn get_log_sender_config() -> Arc<RwLock<Option<LogSenderConfig>>> {
+    LOG_SENDER_CONFIG
+        .get_or_init(|| Arc::new(RwLock::new(None)))
+        .clone()
+}
+
+pub fn set_log_sender_config(config: LogSenderConfig) {
+    let store = get_log_sender_config();
+    if let Ok(mut guard) = store.write() {
+        *guard = Some(config);
+    }
+}
+
+/// Unified event types that can be sent to the /events endpoint
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "event_type")]
 pub enum UnifiedEvent {
     #[serde(rename = "http_access_log")]
     HttpAccessLog(crate::access_log::HttpAccessLog),
-    #[serde(rename = "dropped_ip")]
+    #[serde(rename = "dropped_ips")]
     DroppedIp(crate::bpf_stats::DroppedIpEvent),
     #[serde(rename = "tcp_fingerprint")]
     TcpFingerprint(crate::utils::tcp_fingerprint::TcpFingerprintEvent),
     #[serde(rename = "agent_status")]
     AgentStatus(crate::agent_status::AgentStatusEvent),
-}
-
-impl UnifiedEvent {
-    /// Get the event type as a string
+}
+
+impl UnifiedEvent {
+    /// Get the event type as a string
     pub fn event_type(&self) -> &'static str {
         match self {
             UnifiedEvent::HttpAccessLog(_) => "http_access_log",
-            UnifiedEvent::DroppedIp(_) => "dropped_ip",
+            UnifiedEvent::DroppedIp(_) => "dropped_ips",
             UnifiedEvent::TcpFingerprint(_) => "tcp_fingerprint",
             UnifiedEvent::AgentStatus(_) => "agent_status",
         }
     }
-
-    /// Get the timestamp of the event
-    pub fn timestamp(&self) -> DateTime<Utc> {
-        match self {
-            UnifiedEvent::HttpAccessLog(event) => event.timestamp,
-            UnifiedEvent::DroppedIp(event) => event.timestamp,
-            UnifiedEvent::TcpFingerprint(event) => event.timestamp,
-            UnifiedEvent::AgentStatus(event) => event.timestamp,
-        }
-    }
-
-    /// Convert to JSON string
-    pub fn to_json(&self) -> Result<String, serde_json::Error> {
-        serde_json::to_string(self)
-    }
-}
-
-/// Buffer for storing events
before batch sending -#[derive(Debug)] -struct EventBuffer { - events: Vec, - failed_events: Vec, // Store events that failed to send - total_size_bytes: usize, - failed_size_bytes: usize, - last_flush_time: Instant, - last_retry_time: Instant, // Track when we last tried to resend failed events -} - -impl EventBuffer { - fn new() -> Self { - Self { - events: Vec::new(), - failed_events: Vec::new(), - total_size_bytes: 0, - failed_size_bytes: 0, - last_flush_time: Instant::now(), - last_retry_time: Instant::now(), - } - } - - fn add_event(&mut self, event: UnifiedEvent) -> usize { - // Estimate event size (rough approximation) - let event_size = estimate_event_size(&event); - self.events.push(event); - self.total_size_bytes += event_size; - self.events.len() - } - - fn should_flush(&self, config: &LogSenderConfig) -> bool { - self.events.len() >= config.batch_size_limit || - self.total_size_bytes >= config.batch_size_bytes || - self.last_flush_time.elapsed().as_secs() >= config.batch_timeout_secs - } - - fn take_events(&mut self) -> Vec { - self.total_size_bytes = 0; - self.last_flush_time = Instant::now(); - std::mem::take(&mut self.events) - } - - fn is_empty(&self) -> bool { - self.events.is_empty() - } - - fn add_failed_events(&mut self, events: Vec) { - // Calculate how many events we can add without exceeding the limit - let current_count = self.failed_events.len(); - let available_capacity = MAX_FAILED_EVENTS.saturating_sub(current_count); - - if available_capacity == 0 { - // Buffer full, drop oldest events to make room - let events_to_drop = events.len().min(self.failed_events.len()); - for _ in 0..events_to_drop { - if let Some(dropped) = self.failed_events.first() { - self.failed_size_bytes = self.failed_size_bytes.saturating_sub(estimate_event_size(dropped)); - } - self.failed_events.remove(0); - } - log::warn!("Failed events buffer full, dropped {} oldest events", events_to_drop); - } - - // Add new events (up to the limit) - let events_to_add = events.into_iter().take(MAX_FAILED_EVENTS.saturating_sub(self.failed_events.len())); - for event in events_to_add { - let event_size = estimate_event_size(&event); - self.failed_events.push(event); - self.failed_size_bytes += event_size; - } - } - - fn should_retry_failed_events(&self) -> bool { - // Retry failed events every 30 seconds - !self.failed_events.is_empty() && - self.last_retry_time.elapsed().as_secs() >= 30 - } - - fn take_failed_events(&mut self) -> Vec { - self.failed_size_bytes = 0; - self.last_retry_time = Instant::now(); - std::mem::take(&mut self.failed_events) - } - - fn has_failed_events(&self) -> bool { - !self.failed_events.is_empty() - } -} - -/// Estimate the size of an event in bytes -fn estimate_event_size(event: &UnifiedEvent) -> usize { - // Rough estimation based on JSON serialization - // This is an approximation - actual size may vary - let base_size = 500; // Base overhead - - match event { - UnifiedEvent::HttpAccessLog(log) => { - base_size + log.http.body.len() + log.response.body.len() + - log.http.headers.len() * 50 // Rough estimate for headers - } - UnifiedEvent::DroppedIp(_) => base_size + 200, // Dropped IP events are relatively small - UnifiedEvent::TcpFingerprint(_) => base_size + 100, // TCP fingerprint events are small - UnifiedEvent::AgentStatus(_) => base_size + 300, // Agent status events are small - } -} - -/// Channel for sending events to the batch processor -static EVENT_CHANNEL: std::sync::OnceLock> = std::sync::OnceLock::new(); - -pub fn get_event_channel() -> Option<&'static 
mpsc::UnboundedSender> { - EVENT_CHANNEL.get() -} - -pub fn set_event_channel(sender: mpsc::UnboundedSender) { - let _ = EVENT_CHANNEL.set(sender); -} - -/// Send an event to the unified queue -pub fn send_event(event: UnifiedEvent) { - if let Some(sender) = get_event_channel() { - if let Err(e) = sender.send(event) { - log::warn!("Failed to send event to queue: {}", e); - } - } else { - // Event channel not initialized - this is expected when log_sending_enabled is false - log::trace!("Event channel not initialized, skipping event queuing"); - } -} - -/// Send a batch of events to the /events endpoint -/// Automatically splits large batches into chunks of API_MAX_BATCH_SIZE -async fn send_event_batch(events: Vec) -> Result<(), Box> { - if events.is_empty() { - return Ok(()); - } - - let config = { - let config_store = get_log_sender_config(); - let config_guard = config_store.read().unwrap(); - config_guard.as_ref().cloned() - }; - - let config = match config { - Some(config) => { - if !config.should_send_logs() { - return Ok(()); - } - config - } - None => return Ok(()), - }; - - // Use shared HTTP client with keepalive instead of creating new client - let client = http_client::get_global_reqwest_client() - .map_err(|e| format!("Failed to get global HTTP client: {}", e))?; - - let url = format!("{}/events", config.base_url); - - // Split events into chunks of API_MAX_BATCH_SIZE to respect API limits - let chunks: Vec<_> = events.chunks(API_MAX_BATCH_SIZE).collect(); - let total_events = events.len(); - - for (chunk_idx, chunk) in chunks.iter().enumerate() { - let json = serde_json::to_string(chunk)?; - - log::debug!("Sending chunk {}/{} ({} events) to {}", - chunk_idx + 1, chunks.len(), chunk.len(), url); - - let response = client - .post(&url) - .header("Authorization", format!("Bearer {}", config.api_key)) - .header("Content-Type", "application/json") - .body(json) - .send() - .await?; - - if !response.status().is_success() { - let status = response.status(); - let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); - log::warn!("Failed to send event batch chunk {}/{} to /events endpoint: {} - {} (chunk size: {}, total batch: {})", - chunk_idx + 1, chunks.len(), status, error_text, chunk.len(), total_events); - return Err(format!("HTTP {}: {}", status, error_text).into()); - } else { - log::debug!("Successfully sent event batch chunk {}/{} to /events endpoint (chunk size: {})", - chunk_idx + 1, chunks.len(), chunk.len()); - } - } - - log::debug!("Successfully sent all {} events in {} chunk(s) to /events endpoint", total_events, chunks.len()); - Ok(()) -} - -/// Log sender worker that batches and sends events -pub struct LogSenderWorker { - check_interval_secs: u64, -} - -impl LogSenderWorker { - pub fn new(check_interval_secs: u64) -> Self { - Self { - check_interval_secs, - } - } -} - -impl super::Worker for LogSenderWorker { - fn name(&self) -> &str { - "log_sender" - } - - fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { - let check_interval_secs = self.check_interval_secs; - let worker_name = self.name().to_string(); - - // Create event channel - let (sender, mut receiver) = mpsc::unbounded_channel::(); - set_event_channel(sender); - - tokio::spawn(async move { - log::info!("[{}] Starting log sender worker", worker_name); - - let mut buffer = EventBuffer::new(); - let mut flush_interval = interval(Duration::from_secs(check_interval_secs)); - flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - 
- loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] Shutdown signal received, flushing remaining events", worker_name); - - // Flush any remaining events before exiting - if !buffer.is_empty() { - let events = buffer.take_events(); - if let Err(e) = send_event_batch(events.clone()).await { - log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); - buffer.add_failed_events(events); - } - } - - // Also try to flush any remaining failed events - if buffer.has_failed_events() { - let failed_events = buffer.take_failed_events(); - log::warn!("[{}] Storing {} failed events locally (endpoint unavailable)", worker_name, failed_events.len()); - } - - log::info!("[{}] Log sender worker stopped", worker_name); - break; - } - } - - // Receive new events - event = receiver.recv() => { - match event { - Some(event) => { - let count = buffer.add_event(event); - log::trace!("[{}] Added event to buffer, total: {}", worker_name, count); - } - None => { - log::info!("[{}] Event channel closed, flushing remaining events", worker_name); - - // Flush any remaining events before exiting - if !buffer.is_empty() { - let events = buffer.take_events(); - if let Err(e) = send_event_batch(events.clone()).await { - log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); - buffer.add_failed_events(events); - } - } - - // Also try to flush any remaining failed events - if buffer.has_failed_events() { - let failed_events = buffer.take_failed_events(); - log::warn!("[{}] Storing {} failed events locally (endpoint unavailable)", worker_name, failed_events.len()); - } - - break; - } - } - } - - // Periodic flush check - _ = flush_interval.tick() => { - let config = { - let config_store = get_log_sender_config(); - let config_guard = config_store.read().unwrap(); - config_guard.as_ref().cloned() - }; - - if let Some(config) = config { - // Handle regular event flushing - if buffer.should_flush(&config) { - let events = buffer.take_events(); - if !events.is_empty() { - log::debug!("[{}] Flushing event batch: {} events", worker_name, events.len()); - log::debug!("[{}] Events: {:?}", worker_name, events); - if let Err(e) = send_event_batch(events.clone()).await { - log::warn!("[{}] Failed to send event batch: {}, storing locally for retry", worker_name, e); - buffer.add_failed_events(events); - } - } - } - - // Handle retry of failed events - if buffer.should_retry_failed_events() { - let failed_events = buffer.take_failed_events(); - if !failed_events.is_empty() { - log::debug!("[{}] Retrying failed event batch: {} events", worker_name, failed_events.len()); - if let Err(e) = send_event_batch(failed_events.clone()).await { - log::warn!("[{}] Failed to retry event batch: {}, storing locally again", worker_name, e); - buffer.add_failed_events(failed_events); - } - } - } - } - } - } - } - }) - } -} - + + /// Get the timestamp of the event + pub fn timestamp(&self) -> DateTime { + match self { + UnifiedEvent::HttpAccessLog(event) => event.timestamp, + UnifiedEvent::DroppedIp(event) => event.timestamp, + UnifiedEvent::TcpFingerprint(event) => event.timestamp, + UnifiedEvent::AgentStatus(event) => event.timestamp, + } + } + + /// Convert to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } +} + +/// Buffer for storing events before batch sending +#[derive(Debug)] +struct EventBuffer { + events: Vec, + failed_events: Vec, // Store events that failed to send + total_size_bytes: usize, + 
failed_size_bytes: usize, + last_flush_time: Instant, + last_retry_time: Instant, // Track when we last tried to resend failed events +} + +impl EventBuffer { + fn new() -> Self { + Self { + events: Vec::new(), + failed_events: Vec::new(), + total_size_bytes: 0, + failed_size_bytes: 0, + last_flush_time: Instant::now(), + last_retry_time: Instant::now(), + } + } + + fn add_event(&mut self, event: UnifiedEvent) -> usize { + // Estimate event size (rough approximation) + let event_size = estimate_event_size(&event); + self.events.push(event); + self.total_size_bytes += event_size; + self.events.len() + } + + fn should_flush(&self, config: &LogSenderConfig) -> bool { + self.events.len() >= config.batch_size_limit || + self.total_size_bytes >= config.batch_size_bytes || + self.last_flush_time.elapsed().as_secs() >= config.batch_timeout_secs + } + + fn take_events(&mut self) -> Vec { + self.total_size_bytes = 0; + self.last_flush_time = Instant::now(); + std::mem::take(&mut self.events) + } + + fn is_empty(&self) -> bool { + self.events.is_empty() + } + + fn add_failed_events(&mut self, events: Vec) { + // Calculate how many events we can add without exceeding the limit + let current_count = self.failed_events.len(); + let available_capacity = MAX_FAILED_EVENTS.saturating_sub(current_count); + + if available_capacity == 0 { + // Buffer full, drop oldest events to make room + let events_to_drop = events.len().min(self.failed_events.len()); + for _ in 0..events_to_drop { + if let Some(dropped) = self.failed_events.first() { + self.failed_size_bytes = self.failed_size_bytes.saturating_sub(estimate_event_size(dropped)); + } + self.failed_events.remove(0); + } + log::warn!("Failed events buffer full, dropped {} oldest events", events_to_drop); + } + + // Add new events (up to the limit) + let events_to_add = events.into_iter().take(MAX_FAILED_EVENTS.saturating_sub(self.failed_events.len())); + for event in events_to_add { + let event_size = estimate_event_size(&event); + self.failed_events.push(event); + self.failed_size_bytes += event_size; + } + } + + fn should_retry_failed_events(&self) -> bool { + // Retry failed events every 30 seconds + !self.failed_events.is_empty() && + self.last_retry_time.elapsed().as_secs() >= 30 + } + + fn take_failed_events(&mut self) -> Vec { + self.failed_size_bytes = 0; + self.last_retry_time = Instant::now(); + std::mem::take(&mut self.failed_events) + } + + fn has_failed_events(&self) -> bool { + !self.failed_events.is_empty() + } +} + +/// Estimate the size of an event in bytes +fn estimate_event_size(event: &UnifiedEvent) -> usize { + // Rough estimation based on JSON serialization + // This is an approximation - actual size may vary + let base_size = 500; // Base overhead + + match event { + UnifiedEvent::HttpAccessLog(log) => { + base_size + log.http.body.len() + log.response.body.len() + + log.http.headers.len() * 50 // Rough estimate for headers + } + UnifiedEvent::DroppedIp(_) => base_size + 200, // Dropped IP events are relatively small + UnifiedEvent::TcpFingerprint(_) => base_size + 100, // TCP fingerprint events are small + UnifiedEvent::AgentStatus(_) => base_size + 300, // Agent status events are small + } +} + +/// Channel for sending events to the batch processor +static EVENT_CHANNEL: std::sync::OnceLock> = std::sync::OnceLock::new(); + +pub fn get_event_channel() -> Option<&'static mpsc::UnboundedSender> { + EVENT_CHANNEL.get() +} + +pub fn set_event_channel(sender: mpsc::UnboundedSender) { + let _ = EVENT_CHANNEL.set(sender); +} + +/// Send an 
event to the unified queue +pub fn send_event(event: UnifiedEvent) { + if let Some(sender) = get_event_channel() { + if let Err(e) = sender.send(event) { + log::warn!("Failed to send event to queue: {}", e); + } + } else { + // Event channel not initialized - this is expected when log_sending_enabled is false + log::trace!("Event channel not initialized, skipping event queuing"); + } +} + +/// Send a batch of events to the /events endpoint +/// Automatically splits large batches into chunks of API_MAX_BATCH_SIZE +async fn send_event_batch(events: Vec) -> Result<(), Box> { + if events.is_empty() { + return Ok(()); + } + + let config = { + let config_store = get_log_sender_config(); + let config_guard = config_store.read().unwrap(); + config_guard.as_ref().cloned() + }; + + let config = match config { + Some(config) => { + if !config.should_send_logs() { + return Ok(()); + } + config + } + None => return Ok(()), + }; + + // Use shared HTTP client with keepalive instead of creating new client + let client = http_client::get_global_reqwest_client() + .map_err(|e| format!("Failed to get global HTTP client: {}", e))?; + + let url = format!("{}/events", config.base_url); + + // Split events into chunks of API_MAX_BATCH_SIZE to respect API limits + let chunks: Vec<_> = events.chunks(API_MAX_BATCH_SIZE).collect(); + let total_events = events.len(); + + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let json = serde_json::to_string(chunk)?; + + log::debug!("Sending chunk {}/{} ({} events) to {}", + chunk_idx + 1, chunks.len(), chunk.len(), url); + + let response = client + .post(&url) + .header("Authorization", format!("Bearer {}", config.api_key)) + .header("Content-Type", "application/json") + .body(json) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + log::warn!("Failed to send event batch chunk {}/{} to /events endpoint: {} - {} (chunk size: {}, total batch: {})", + chunk_idx + 1, chunks.len(), status, error_text, chunk.len(), total_events); + return Err(format!("HTTP {}: {}", status, error_text).into()); + } else { + log::debug!("Successfully sent event batch chunk {}/{} to /events endpoint (chunk size: {})", + chunk_idx + 1, chunks.len(), chunk.len()); + } + } + + log::debug!("Successfully sent all {} events in {} chunk(s) to /events endpoint", total_events, chunks.len()); + Ok(()) +} + +/// Log sender worker that batches and sends events +pub struct LogSenderWorker { + check_interval_secs: u64, +} + +impl LogSenderWorker { + pub fn new(check_interval_secs: u64) -> Self { + Self { + check_interval_secs, + } + } +} + +impl super::Worker for LogSenderWorker { + fn name(&self) -> &str { + "log_sender" + } + + fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { + let check_interval_secs = self.check_interval_secs; + let worker_name = self.name().to_string(); + + // Create event channel + let (sender, mut receiver) = mpsc::unbounded_channel::(); + set_event_channel(sender); + + tokio::spawn(async move { + log::info!("[{}] Starting log sender worker", worker_name); + + let mut buffer = EventBuffer::new(); + let mut flush_interval = interval(Duration::from_secs(check_interval_secs)); + flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + tokio::select! 
{ + _ = shutdown.changed() => { + if *shutdown.borrow() { + log::info!("[{}] Shutdown signal received, flushing remaining events", worker_name); + + // Flush any remaining events before exiting + if !buffer.is_empty() { + let events = buffer.take_events(); + if let Err(e) = send_event_batch(events.clone()).await { + log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); + buffer.add_failed_events(events); + } + } + + // Also try to flush any remaining failed events + if buffer.has_failed_events() { + let failed_events = buffer.take_failed_events(); + log::warn!("[{}] Storing {} failed events locally (endpoint unavailable)", worker_name, failed_events.len()); + } + + log::info!("[{}] Log sender worker stopped", worker_name); + break; + } + } + + // Receive new events + event = receiver.recv() => { + match event { + Some(event) => { + let count = buffer.add_event(event); + log::trace!("[{}] Added event to buffer, total: {}", worker_name, count); + } + None => { + log::info!("[{}] Event channel closed, flushing remaining events", worker_name); + + // Flush any remaining events before exiting + if !buffer.is_empty() { + let events = buffer.take_events(); + if let Err(e) = send_event_batch(events.clone()).await { + log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); + buffer.add_failed_events(events); + } + } + + // Also try to flush any remaining failed events + if buffer.has_failed_events() { + let failed_events = buffer.take_failed_events(); + log::warn!("[{}] Storing {} failed events locally (endpoint unavailable)", worker_name, failed_events.len()); + } + + break; + } + } + } + + // Periodic flush check + _ = flush_interval.tick() => { + let config = { + let config_store = get_log_sender_config(); + let config_guard = config_store.read().unwrap(); + config_guard.as_ref().cloned() + }; + + if let Some(config) = config { + // Handle regular event flushing + if buffer.should_flush(&config) { + let events = buffer.take_events(); + if !events.is_empty() { + log::debug!("[{}] Flushing event batch: {} events", worker_name, events.len()); + log::debug!("[{}] Events: {:?}", worker_name, events); + if let Err(e) = send_event_batch(events.clone()).await { + log::warn!("[{}] Failed to send event batch: {}, storing locally for retry", worker_name, e); + buffer.add_failed_events(events); + } + } + } + + // Handle retry of failed events + if buffer.should_retry_failed_events() { + let failed_events = buffer.take_failed_events(); + if !failed_events.is_empty() { + log::debug!("[{}] Retrying failed event batch: {} events", worker_name, failed_events.len()); + if let Err(e) = send_event_batch(failed_events.clone()).await { + log::warn!("[{}] Failed to retry event batch: {}, storing locally again", worker_name, e); + buffer.add_failed_events(failed_events); + } + } + } + } + } + } + } + }) + } +} + diff --git a/src/worker/manager.rs b/src/worker/manager.rs index 7a2df95..5f439f9 100644 --- a/src/worker/manager.rs +++ b/src/worker/manager.rs @@ -1,99 +1,99 @@ -use std::collections::HashMap; -use tokio::sync::watch; -use tokio::task::JoinHandle; - -/// Worker trait that all workers must implement -pub trait Worker: Send + Sync + 'static { - /// Name of the worker - fn name(&self) -> &str; - - /// Run the worker task - fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()>; -} - -/// Worker configuration -#[derive(Debug, Clone)] -pub struct WorkerConfig { - /// Worker name/identifier - pub name: String, - /// Schedule interval in 
seconds - pub interval_secs: u64, - /// Whether the worker is enabled - pub enabled: bool, -} - -/// Worker manager that manages multiple workers with their own schedules -pub struct WorkerManager { - workers: HashMap)>, - shutdown_tx: watch::Sender, -} - -impl WorkerManager { - /// Create a new worker manager - pub fn new() -> (Self, watch::Receiver) { - let (shutdown_tx, shutdown_rx) = watch::channel(false); - ( - Self { - workers: HashMap::new(), - shutdown_tx, - }, - shutdown_rx, - ) - } - - /// Register a worker with its configuration - pub fn register_worker( - &mut self, - config: WorkerConfig, - worker: W, - ) -> Result<(), String> { - if !config.enabled { - log::info!("Worker '{}' is disabled, skipping registration", config.name); - return Ok(()); - } - - if self.workers.contains_key(&config.name) { - return Err(format!("Worker '{}' is already registered", config.name)); - } - - let shutdown_rx = self.shutdown_tx.subscribe(); - let handle = worker.run(shutdown_rx); - - self.workers.insert(config.name.clone(), (config, handle)); - Ok(()) - } - - /// Get all worker handles for graceful shutdown - pub fn get_handles(&self) -> Vec<(&str, &JoinHandle<()>)> { - self.workers - .iter() - .map(|(name, (_, handle))| (name.as_str(), handle)) - .collect() - } - - /// Shutdown all workers - pub fn shutdown(&mut self) { - log::info!("Shutting down {} workers...", self.workers.len()); - let _ = self.shutdown_tx.send(true); - } - - /// Wait for all workers to complete - pub async fn wait_for_all(&mut self) { - let handles: Vec<_> = self.workers.drain().map(|(_, (_, handle))| handle).collect(); - - for handle in handles { - if let Err(e) = handle.await { - log::error!("Worker task join error: {}", e); - } - } - - log::info!("All workers stopped"); - } -} - -impl Default for WorkerManager { - fn default() -> Self { - Self::new().0 - } -} - +use std::collections::HashMap; +use tokio::sync::watch; +use tokio::task::JoinHandle; + +/// Worker trait that all workers must implement +pub trait Worker: Send + Sync + 'static { + /// Name of the worker + fn name(&self) -> &str; + + /// Run the worker task + fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()>; +} + +/// Worker configuration +#[derive(Debug, Clone)] +pub struct WorkerConfig { + /// Worker name/identifier + pub name: String, + /// Schedule interval in seconds + pub interval_secs: u64, + /// Whether the worker is enabled + pub enabled: bool, +} + +/// Worker manager that manages multiple workers with their own schedules +pub struct WorkerManager { + workers: HashMap)>, + shutdown_tx: watch::Sender, +} + +impl WorkerManager { + /// Create a new worker manager + pub fn new() -> (Self, watch::Receiver) { + let (shutdown_tx, shutdown_rx) = watch::channel(false); + ( + Self { + workers: HashMap::new(), + shutdown_tx, + }, + shutdown_rx, + ) + } + + /// Register a worker with its configuration + pub fn register_worker( + &mut self, + config: WorkerConfig, + worker: W, + ) -> Result<(), String> { + if !config.enabled { + log::info!("Worker '{}' is disabled, skipping registration", config.name); + return Ok(()); + } + + if self.workers.contains_key(&config.name) { + return Err(format!("Worker '{}' is already registered", config.name)); + } + + let shutdown_rx = self.shutdown_tx.subscribe(); + let handle = worker.run(shutdown_rx); + + self.workers.insert(config.name.clone(), (config, handle)); + Ok(()) + } + + /// Get all worker handles for graceful shutdown + pub fn get_handles(&self) -> Vec<(&str, &JoinHandle<()>)> { + self.workers + .iter() + 
.map(|(name, (_, handle))| (name.as_str(), handle)) + .collect() + } + + /// Shutdown all workers + pub fn shutdown(&mut self) { + log::info!("Shutting down {} workers...", self.workers.len()); + let _ = self.shutdown_tx.send(true); + } + + /// Wait for all workers to complete + pub async fn wait_for_all(&mut self) { + let handles: Vec<_> = self.workers.drain().map(|(_, (_, handle))| handle).collect(); + + for handle in handles { + if let Err(e) = handle.await { + log::error!("Worker task join error: {}", e); + } + } + + log::info!("All workers stopped"); + } +} + +impl Default for WorkerManager { + fn default() -> Self { + Self::new().0 + } +} + diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 01def73..cfdf11c 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -1,10 +1,10 @@ -#[cfg(feature = "proxy")] -pub mod certificate; -pub mod config; -pub mod geoip_mmdb; -pub mod log; -pub mod manager; -pub mod threat_mmdb; - -pub use manager::{Worker, WorkerConfig, WorkerManager}; -pub mod agent_status; +#[cfg(feature = "proxy")] +pub mod certificate; +pub mod config; +pub mod geoip_mmdb; +pub mod log; +pub mod manager; +pub mod threat_mmdb; + +pub use manager::{Worker, WorkerConfig, WorkerManager}; +pub mod agent_status; diff --git a/src/worker/threat_mmdb.rs b/src/worker/threat_mmdb.rs index 1110c1a..dc4452e 100644 --- a/src/worker/threat_mmdb.rs +++ b/src/worker/threat_mmdb.rs @@ -1,333 +1,333 @@ -use std::collections::HashMap; -use std::path::PathBuf; -use std::time::Duration; - -use anyhow::{anyhow, Context, Result}; -use serde::Deserialize; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tokio::time::{interval, Duration as TokioDuration, MissedTickBehavior}; - -use crate::http_client::get_global_reqwest_client; -use crate::threat; -use crate::worker::Worker; - -#[derive(Debug, Deserialize)] -struct VersionResponse { - success: bool, - #[allow(dead_code)] - timestamp: String, - version: String, - #[allow(dead_code)] - hash: String, -} - -/// Periodically checks version.txt and downloads the latest Threat MMDB, -/// then asks the threat module to reload it from disk. 
-pub struct ThreatMmdbWorker { - interval_secs: u64, - mmdb_base_url: String, - mmdb_path: Option, - headers: Option>, - api_key: String, -} - -impl ThreatMmdbWorker { - pub fn new( - interval_secs: u64, - mmdb_base_url: String, - mmdb_path: Option, - headers: Option>, - api_key: String, - ) -> Self { - Self { - interval_secs, - mmdb_base_url, - mmdb_path, - headers, - api_key, - } - } -} - -impl Worker for ThreatMmdbWorker { - fn name(&self) -> &str { - "threat_mmdb" - } - - fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { - let interval_secs = self.interval_secs; - let mmdb_base_url = self.mmdb_base_url.clone(); - let mmdb_path = self.mmdb_path.clone(); - let mut headers = self.headers.clone().unwrap_or_default(); - let api_key = self.api_key.clone(); - - // Always add API key to headers if provided - if !api_key.is_empty() { - headers.insert("Authorization".to_string(), format!("Bearer {}", api_key)); - } - - let headers = if headers.is_empty() { None } else { Some(headers) }; - - tokio::spawn(async move { - let worker_name = "threat_mmdb".to_string(); - let mut current_version: Option = None; - - log::info!( - "[{}] Starting Threat MMDB worker (interval: {}s)", - worker_name, - interval_secs - ); - - // Initial sync on startup - match sync_mmdb( - &mmdb_base_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - ) - .await - { - Ok(new_ver) => { - current_version = new_ver; - } - Err(e) => { - log::warn!( - "[{}] Initial Threat MMDB sync failed: {}", - worker_name, - e - ); - } - } - - let mut ticker = interval(TokioDuration::from_secs(interval_secs)); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = ticker.tick() => { - match sync_mmdb( - &mmdb_base_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - ).await { - Ok(new_ver) => { - if new_ver != current_version { - current_version = new_ver; - log::info!("[{}] Threat MMDB updated to latest version", worker_name); - } else { - log::debug!("[{}] Threat MMDB already at latest version", worker_name); - } - } - Err(e) => { - log::warn!("[{}] Periodic Threat MMDB sync failed: {}", worker_name, e); - } - } - } - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] Threat MMDB worker received shutdown signal", worker_name); - break; - } - } - } - } - }) - } -} - -/// One sync pass: check version API, download a new MMDB if needed, write it -/// to disk, and ask the threat module to reload it. 
-async fn sync_mmdb(
-    mmdb_base_url: &str,
-    mmdb_path: Option<PathBuf>,
-    current_version: Option<String>,
-    headers: Option<HashMap<String, String>>,
-) -> Result<Option<String>> {
-    // Construct version API URL from base URL
-    let base = mmdb_base_url.trim_end_matches('/');
-    let versions_url = format!("{}/indicators/version", base);
-
-    // Check L1 cache first using pingora-memory-cache
-    let cache = threat::get_version_cache()
-        .context("Failed to get version cache")?;
-    let cache_key = format!("threat_mmdb_version:{}", base);
-
-    // Try to get cached version
-    let (cached_version, cache_status) = cache.get(&cache_key);
-
-    // Download version from API endpoint
-    let latest = if versions_url.starts_with("http://") || versions_url.starts_with("https://") {
-        let client = get_global_reqwest_client()
-            .context("Failed to get HTTP client for version download")?;
-        let mut req = client.get(&versions_url).timeout(Duration::from_secs(120));
-        if let Some(ref hdrs) = headers {
-            for (key, value) in hdrs {
-                req = req.header(key, value);
-            }
-        }
-        let resp = req
-            .send()
-            .await
-            .with_context(|| format!("Failed to download version from {} (timeout: 120s)", versions_url))?;
-        let status = resp.status();
-        if !status.is_success() {
-            return Err(anyhow!(
-                "Version download failed: status {} from {}",
-                status,
-                versions_url
-            ));
-        }
-        let version_resp: VersionResponse = resp
-            .json()
-            .await
-            .with_context(|| format!("Failed to parse version JSON from {}", versions_url))?;
-
-        if !version_resp.success {
-            return Err(anyhow!("Version API returned success=false from {}", versions_url));
-        }
-
-        version_resp.version
-    } else {
-        // For file:// URLs, try to parse as JSON file
-        let file_path = versions_url.strip_prefix("file://").unwrap_or(&versions_url);
-        let content = tokio::fs::read_to_string(file_path)
-            .await
-            .with_context(|| format!("Failed to read version file from {}", file_path))?;
-        let version_resp: VersionResponse = serde_json::from_str(&content)
-            .with_context(|| format!("Failed to parse version JSON from {}", file_path))?;
-
-        if !version_resp.success {
-            return Err(anyhow!("Version file returned success=false from {}", file_path));
-        }
-
-        version_resp.version
-    };
-
-    // Check if cached version matches latest version
-    if let Some(cached_ver) = cached_version {
-        if cache_status.is_hit() && cached_ver == latest {
-            // Version matches cache, skip download
-            log::debug!("Threat MMDB version {} matches cache, skipping download", latest);
-            return Ok(Some(latest));
-        }
-    }
-
-    // If we have a current_version and it matches latest, also skip
-    if let Some(ref curr) = current_version {
-        if curr == &latest {
-            // Already on latest, nothing to do.
-            return Ok(current_version);
-        }
-    }
-
-    let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for Threat MMDB worker"))?;
-
-    // If the path doesn't have a file extension, treat it as a directory and append the filename
-    if local_path.extension().is_none() {
-        local_path = local_path.join("threat.mmdb");
-    }
-
-    let base = mmdb_base_url.trim_end_matches('/');
-    let bytes = if base.starts_with("http://") || base.starts_with("https://") {
-        let url = format!("{}/indicators/download?version={}", base, latest);
-        download_mmdb(&url, headers.as_ref()).await?
-    } else {
-        // Strip file:// prefix if present
-        let file_base = base.strip_prefix("file://").unwrap_or(base);
-        let src_path = PathBuf::from(file_base).join(&latest);
-        tokio::fs::read(&src_path)
-            .await
-            .with_context(|| format!("Failed to read MMDB from {:?}", src_path))?
-    };
-
-    if let Some(parent) = local_path.parent() {
-        tokio::fs::create_dir_all(parent)
-            .await
-            .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?;
-    }
-
-    tokio::fs::write(&local_path, &bytes)
-        .await
-        .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?;
-
-    log::info!("Threat MMDB written to {:?}", local_path);
-
-    // Update L1 cache with new version (no TTL, cache indefinitely)
-    cache.put(&cache_key, latest.clone(), None);
-    log::debug!("Updated threat MMDB version cache: {}", latest);
-
-    // Ask the threat module to reload from the updated path.
-    crate::threat::refresh_threat_mmdb().await?;
-
-    Ok(Some(latest))
-}
-
-
-async fn download_mmdb(url: &str, headers: Option<&HashMap<String, String>>) -> Result<Vec<u8>> {
-    let client = get_global_reqwest_client()
-        .context("Failed to get HTTP client for MMDB download")?;
-
-    // Use a longer timeout for large files (120s base + 1s per MB)
-    // For 172MB file, that would be ~292s, but we'll use a fixed 600s (10min) for very large files
-    let timeout_secs = 600; // 10 minutes for large MMDB files
-    let mut req = client.get(url).timeout(Duration::from_secs(timeout_secs));
-    if let Some(hdrs) = headers {
-        for (key, value) in hdrs {
-            req = req.header(key, value);
-        }
-    }
-    let resp = req
-        .send()
-        .await
-        .with_context(|| format!("Failed to download MMDB from {} (timeout: 120s)", url))?;
-    let status = resp.status();
-
-    // Get content length for logging (before consuming response)
-    let content_length = resp.content_length();
-
-    if !status.is_success() {
-        // Try to read error body, but don't fail if we can't
-        let status_text = match resp.text().await {
-            Ok(text) if !text.is_empty() => text,
-            _ => format!("HTTP {}", status),
-        };
-        return Err(anyhow!(
-            "MMDB download failed: status {} from {} - {}",
-            status,
-            url,
-            status_text
-        ));
-    }
-
-    let file_size_mb = content_length.map(|s| s as f64 / 1_048_576.0);
-    log::info!(
-        "Downloading MMDB from {} (content-length: {:?}, ~{:.2} MB, timeout: {}s)",
-        url,
-        content_length,
-        file_size_mb.unwrap_or(0.0),
-        timeout_secs
-    );
-
-    let bytes = resp
-        .bytes()
-        .await
-        .with_context(|| {
-            format!(
-                "Failed to read MMDB body from {} (content-length: {:?}, ~{:.2} MB). This may indicate a network timeout or connection issue.",
-                url,
-                content_length,
-                file_size_mb.unwrap_or(0.0)
-            )
-        })?;
-
-    let data = bytes.to_vec();
-    log::debug!(
-        "Successfully downloaded MMDB from {} (size: {} bytes)",
-        url,
-        data.len()
-    );
-
-    Ok(data)
-}
+use std::collections::HashMap;
+use std::path::PathBuf;
+use std::time::Duration;
+
+use anyhow::{anyhow, Context, Result};
+use serde::Deserialize;
+use tokio::sync::watch;
+use tokio::task::JoinHandle;
+use tokio::time::{interval, Duration as TokioDuration, MissedTickBehavior};
+
+use crate::http_client::get_global_reqwest_client;
+use crate::threat;
+use crate::worker::Worker;
+
+#[derive(Debug, Deserialize)]
+struct VersionResponse {
+    success: bool,
+    #[allow(dead_code)]
+    timestamp: String,
+    version: String,
+    #[allow(dead_code)]
+    hash: String,
+}
+
+/// Periodically checks the version endpoint and downloads the latest Threat MMDB,
+/// then asks the threat module to reload it from disk.
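+///
+/// If `api_key` is non-empty, it is sent as `Authorization: Bearer <key>` on every
+/// request, in addition to any caller-supplied `headers`.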
+pub struct ThreatMmdbWorker {
+    interval_secs: u64,
+    mmdb_base_url: String,
+    mmdb_path: Option<PathBuf>,
+    headers: Option<HashMap<String, String>>,
+    api_key: String,
+}
+
+impl ThreatMmdbWorker {
+    pub fn new(
+        interval_secs: u64,
+        mmdb_base_url: String,
+        mmdb_path: Option<PathBuf>,
+        headers: Option<HashMap<String, String>>,
+        api_key: String,
+    ) -> Self {
+        Self {
+            interval_secs,
+            mmdb_base_url,
+            mmdb_path,
+            headers,
+            api_key,
+        }
+    }
+}
+
+impl Worker for ThreatMmdbWorker {
+    fn name(&self) -> &str {
+        "threat_mmdb"
+    }
+
+    fn run(&self, mut shutdown: watch::Receiver<bool>) -> JoinHandle<()> {
+        let interval_secs = self.interval_secs;
+        let mmdb_base_url = self.mmdb_base_url.clone();
+        let mmdb_path = self.mmdb_path.clone();
+        let mut headers = self.headers.clone().unwrap_or_default();
+        let api_key = self.api_key.clone();
+
+        // Always add API key to headers if provided
+        if !api_key.is_empty() {
+            headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
+        }
+
+        let headers = if headers.is_empty() { None } else { Some(headers) };
+
+        tokio::spawn(async move {
+            let worker_name = "threat_mmdb".to_string();
+            let mut current_version: Option<String> = None;
+
+            log::info!(
+                "[{}] Starting Threat MMDB worker (interval: {}s)",
+                worker_name,
+                interval_secs
+            );
+
+            // Initial sync on startup
+            match sync_mmdb(
+                &mmdb_base_url,
+                mmdb_path.clone(),
+                current_version.clone(),
+                headers.clone(),
+            )
+            .await
+            {
+                Ok(new_ver) => {
+                    current_version = new_ver;
+                }
+                Err(e) => {
+                    log::warn!(
+                        "[{}] Initial Threat MMDB sync failed: {}",
+                        worker_name,
+                        e
+                    );
+                }
+            }
+
+            let mut ticker = interval(TokioDuration::from_secs(interval_secs));
+            ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
+
+            loop {
+                tokio::select! {
+                    _ = ticker.tick() => {
+                        match sync_mmdb(
+                            &mmdb_base_url,
+                            mmdb_path.clone(),
+                            current_version.clone(),
+                            headers.clone(),
+                        ).await {
+                            Ok(new_ver) => {
+                                if new_ver != current_version {
+                                    current_version = new_ver;
+                                    log::info!("[{}] Threat MMDB updated to latest version", worker_name);
+                                } else {
+                                    log::debug!("[{}] Threat MMDB already at latest version", worker_name);
+                                }
+                            }
+                            Err(e) => {
+                                log::warn!("[{}] Periodic Threat MMDB sync failed: {}", worker_name, e);
+                            }
+                        }
+                    }
+                    _ = shutdown.changed() => {
+                        if *shutdown.borrow() {
+                            log::info!("[{}] Threat MMDB worker received shutdown signal", worker_name);
+                            break;
+                        }
+                    }
+                }
+            }
+        })
+    }
+}
+
+/// One sync pass: check version API, download a new MMDB if needed, write it
+/// to disk, and ask the threat module to reload it.
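+///
+/// Endpoints used (relative to `mmdb_base_url`; shapes inferred from the code below):
+///   GET {base}/indicators/version            -> {"success", "timestamp", "version", "hash"}
+///   GET {base}/indicators/download?version=V -> raw MMDB bytes
+/// A `file://` base URL is read from disk instead of fetched over HTTP.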
+async fn sync_mmdb(
+    mmdb_base_url: &str,
+    mmdb_path: Option<PathBuf>,
+    current_version: Option<String>,
+    headers: Option<HashMap<String, String>>,
+) -> Result<Option<String>> {
+    // Construct version API URL from base URL
+    let base = mmdb_base_url.trim_end_matches('/');
+    let versions_url = format!("{}/indicators/version", base);
+
+    // Check L1 cache first using pingora-memory-cache
+    let cache = threat::get_version_cache()
+        .context("Failed to get version cache")?;
+    let cache_key = format!("threat_mmdb_version:{}", base);
+
+    // Try to get cached version
+    let (cached_version, cache_status) = cache.get(&cache_key);
+
+    // Download version from API endpoint
+    let latest = if versions_url.starts_with("http://") || versions_url.starts_with("https://") {
+        let client = get_global_reqwest_client()
+            .context("Failed to get HTTP client for version download")?;
+        let mut req = client.get(&versions_url).timeout(Duration::from_secs(120));
+        if let Some(ref hdrs) = headers {
+            for (key, value) in hdrs {
+                req = req.header(key, value);
+            }
+        }
+        let resp = req
+            .send()
+            .await
+            .with_context(|| format!("Failed to download version from {} (timeout: 120s)", versions_url))?;
+        let status = resp.status();
+        if !status.is_success() {
+            return Err(anyhow!(
+                "Version download failed: status {} from {}",
+                status,
+                versions_url
+            ));
+        }
+        let version_resp: VersionResponse = resp
+            .json()
+            .await
+            .with_context(|| format!("Failed to parse version JSON from {}", versions_url))?;
+
+        if !version_resp.success {
+            return Err(anyhow!("Version API returned success=false from {}", versions_url));
+        }
+
+        version_resp.version
+    } else {
+        // For file:// URLs, try to parse as JSON file
+        let file_path = versions_url.strip_prefix("file://").unwrap_or(&versions_url);
+        let content = tokio::fs::read_to_string(file_path)
+            .await
+            .with_context(|| format!("Failed to read version file from {}", file_path))?;
+        let version_resp: VersionResponse = serde_json::from_str(&content)
+            .with_context(|| format!("Failed to parse version JSON from {}", file_path))?;
+
+        if !version_resp.success {
+            return Err(anyhow!("Version file returned success=false from {}", file_path));
+        }
+
+        version_resp.version
+    };
+
+    // Check if cached version matches latest version
+    if let Some(cached_ver) = cached_version {
+        if cache_status.is_hit() && cached_ver == latest {
+            // Version matches cache, skip download
+            log::debug!("Threat MMDB version {} matches cache, skipping download", latest);
+            return Ok(Some(latest));
+        }
+    }
+
+    // If we have a current_version and it matches latest, also skip
+    if let Some(ref curr) = current_version {
+        if curr == &latest {
+            // Already on latest, nothing to do.
+            return Ok(current_version);
+        }
+    }
+
+    let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for Threat MMDB worker"))?;
+
+    // If the path doesn't have a file extension, treat it as a directory and append the filename
+    if local_path.extension().is_none() {
+        local_path = local_path.join("threat.mmdb");
+    }
+
+    let base = mmdb_base_url.trim_end_matches('/');
+    let bytes = if base.starts_with("http://") || base.starts_with("https://") {
+        let url = format!("{}/indicators/download?version={}", base, latest);
+        download_mmdb(&url, headers.as_ref()).await?
+    } else {
+        // Strip file:// prefix if present
+        let file_base = base.strip_prefix("file://").unwrap_or(base);
+        let src_path = PathBuf::from(file_base).join(&latest);
+        tokio::fs::read(&src_path)
+            .await
+            .with_context(|| format!("Failed to read MMDB from {:?}", src_path))?
+    };
+
+    if let Some(parent) = local_path.parent() {
+        tokio::fs::create_dir_all(parent)
+            .await
+            .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?;
+    }
+
+    tokio::fs::write(&local_path, &bytes)
+        .await
+        .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?;
+
+    log::info!("Threat MMDB written to {:?}", local_path);
+
+    // Update L1 cache with new version (no TTL, cache indefinitely)
+    cache.put(&cache_key, latest.clone(), None);
+    log::debug!("Updated threat MMDB version cache: {}", latest);
+
+    // Ask the threat module to reload from the updated path.
+    crate::threat::refresh_threat_mmdb().await?;
+
+    Ok(Some(latest))
+}
+
+
+async fn download_mmdb(url: &str, headers: Option<&HashMap<String, String>>) -> Result<Vec<u8>> {
+    let client = get_global_reqwest_client()
+        .context("Failed to get HTTP client for MMDB download")?;
+
+    // Use a longer timeout for large files (120s base + 1s per MB)
+    // For 172MB file, that would be ~292s, but we'll use a fixed 600s (10min) for very large files
+    let timeout_secs = 600; // 10 minutes for large MMDB files
+    let mut req = client.get(url).timeout(Duration::from_secs(timeout_secs));
+    if let Some(hdrs) = headers {
+        for (key, value) in hdrs {
+            req = req.header(key, value);
+        }
+    }
+    let resp = req
+        .send()
+        .await
+        .with_context(|| format!("Failed to download MMDB from {} (timeout: {}s)", url, timeout_secs))?;
+    let status = resp.status();
+
+    // Get content length for logging (before consuming response)
+    let content_length = resp.content_length();
+
+    if !status.is_success() {
+        // Try to read error body, but don't fail if we can't
+        let status_text = match resp.text().await {
+            Ok(text) if !text.is_empty() => text,
+            _ => format!("HTTP {}", status),
+        };
+        return Err(anyhow!(
+            "MMDB download failed: status {} from {} - {}",
+            status,
+            url,
+            status_text
+        ));
+    }
+
+    let file_size_mb = content_length.map(|s| s as f64 / 1_048_576.0);
+    log::info!(
+        "Downloading MMDB from {} (content-length: {:?}, ~{:.2} MB, timeout: {}s)",
+        url,
+        content_length,
+        file_size_mb.unwrap_or(0.0),
+        timeout_secs
+    );
+
+    let bytes = resp
+        .bytes()
+        .await
+        .with_context(|| {
+            format!(
+                "Failed to read MMDB body from {} (content-length: {:?}, ~{:.2} MB). This may indicate a network timeout or connection issue.",
+                url,
+                content_length,
+                file_size_mb.unwrap_or(0.0)
+            )
+        })?;
+
+    let data = bytes.to_vec();
+    log::debug!(
+        "Successfully downloaded MMDB from {} (size: {} bytes)",
+        url,
+        data.len()
+    );
+
+    Ok(data)
+}
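For context, a minimal wiring sketch for the worker above (not part of the patch; the base URL, path, and API key are placeholder values). `WorkerManager` performs the equivalent wiring internally; inside an async context:

    use tokio::sync::watch;

    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let worker = ThreatMmdbWorker::new(
        3600,                                               // sync interval in seconds
        "https://api.example.com/v1".to_string(),           // placeholder base URL
        Some(std::path::PathBuf::from("/var/lib/synapse")), // directory; "threat.mmdb" is appended
        None,                                               // no extra headers
        "example-api-key".to_string(),                      // sent as "Authorization: Bearer ..."
    );
    let handle = worker.run(shutdown_rx);

    // ... later, during shutdown:
    let _ = shutdown_tx.send(true);
    let _ = handle.await;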
From e6468c0b3264808ae54a752365e6db02a09b202c Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Tue, 3 Feb 2026 19:19:23 +0100
Subject: [PATCH 10/14] merge conflict fix

---
 config/config.yaml | 12 ++++++++++++
 src/cli.rs         | 26 +++++++++++++++++++++++---
 src/main.rs        | 42 ++++++++++++++++++++----------------------
 3 files changed, 55 insertions(+), 25 deletions(-)

diff --git a/config/config.yaml b/config/config.yaml
index 33a3980..d366bc5 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -7,6 +7,8 @@
 #
 # Available environment variables (without deprecated AX_ prefix):
 # - MODE (agent or proxy)
+# - MULTI_THREAD (true or false, enables multi-threaded runtime)
+# - WORKER_THREADS (number of worker threads when multi_thread is enabled)
 # - REDIS_URL, REDIS_PREFIX
 # - REDIS_SSL_CA_CERT_PATH, REDIS_SSL_CLIENT_CERT_PATH, REDIS_SSL_CLIENT_KEY_PATH, REDIS_SSL_INSECURE
 # - NETWORK_IFACE, NETWORK_IFACES (comma-separated), NETWORK_DISABLE_XDP, NETWORK_IP_VERSION
@@ -30,6 +32,16 @@
 # - proxy: Full reverse proxy functionality (default)
 mode: "proxy"
 
+# Multi-thread runtime configuration
+# - true: Use multi-threaded Tokio runtime (better for proxy mode with high concurrency)
+# - false: Use single-threaded runtime (lower memory, better for agent mode)
+# Default: false for agent mode, true for proxy mode
+# multi_thread: true
+
+# Number of worker threads when multi_thread is enabled
+# Default: number of CPU cores
+# worker_threads: 4
+
 # Redis Configuration
 redis:
   # Redis connection URL for ACME cache storage
diff --git a/src/cli.rs b/src/cli.rs
index 7481757..c09788a 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -30,9 +30,14 @@ pub struct Config {
     #[serde(default = "default_mode")]
     pub mode: String,
 
-    /// Number of Tokio worker threads. In agent mode, defaults to 0 (single-threaded).
-    /// Set to a positive number to force multi-threaded runtime.
-    /// In proxy mode, defaults to number of CPUs.
+    /// Enable multi-threaded runtime. Defaults to false for agent mode, true for proxy mode.
+    /// When disabled, uses single-threaded runtime for lower memory usage.
+    #[serde(default)]
+    pub multi_thread: Option<bool>,
+
+    /// Number of Tokio worker threads when multi_thread is enabled.
+    /// Defaults to number of CPUs if not specified.
+    /// Ignored when multi_thread is false.
     #[serde(default)]
     pub worker_threads: Option<usize>,
 
@@ -269,6 +274,7 @@ impl Config {
     pub fn default() -> Self {
         Self {
             mode: "agent".to_string(),
+            multi_thread: None,
             worker_threads: None,
             redis: RedisConfig {
                 url: "redis://127.0.0.1/0".to_string(),
@@ -507,6 +513,20 @@ impl Config {
            self.mode = val;
        }
 
+        // Multi-thread override
+        if let Some(val) = get_env("MULTI_THREAD") {
+            if let Ok(parsed) = val.parse::<bool>() {
+                self.multi_thread = Some(parsed);
+            }
+        }
+
+        // Worker threads override
+        if let Some(val) = get_env("WORKER_THREADS") {
+            if let Ok(parsed) = val.parse::<usize>() {
+                self.worker_threads = Some(parsed);
+            }
+        }
+
         // Redis configuration overrides
         if let Some(val) = get_env("REDIS_URL") {
             self.redis.url = val;
diff --git a/src/main.rs b/src/main.rs
index 208b78e..2ba686d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -228,29 +228,27 @@ fn main() -> Result<()> {
     }
 
     // Start the tokio runtime and run the async application
-    let runtime = match config.worker_threads {
-        // Explicit config: use specified thread count (0 = single-threaded)
-        Some(0) => {
-            tokio::runtime::Builder::new_current_thread()
-                .enable_all()
-                .build()?
-        }
-        Some(threads) => {
-            tokio::runtime::Builder::new_multi_thread()
-                .worker_threads(threads)
-                .enable_all()
-                .build()?
-        }
-        None if config.mode == "agent" => {
-            tokio::runtime::Builder::new_current_thread()
-                .enable_all()
-                .build()?
-        }
-        None => {
-            tokio::runtime::Builder::new_multi_thread()
-                .enable_all()
-                .build()?
+    // Determine if we should use multi-threaded runtime
+    let use_multi_thread = config.multi_thread.unwrap_or_else(|| {
+        // Default: agent mode = single-threaded, proxy mode = multi-threaded
+        config.mode != "agent"
+    });
+
+    let runtime = if use_multi_thread {
+        // Multi-threaded runtime
+        let mut builder = tokio::runtime::Builder::new_multi_thread();
+        builder.enable_all();
+        if let Some(threads) = config.worker_threads {
+            if threads > 0 {
+                builder.worker_threads(threads);
+            }
         }
+        builder.build()?
+    } else {
+        // Single-threaded runtime (default for agent mode)
+        tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()?
     };
 
     runtime.block_on(async_main(args, config))
 }
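The runtime selection above, reduced to a self-contained sketch (not the shipped code) so the three cases are easy to verify:

    fn runtime_kind(multi_thread: Option<bool>, mode: &str) -> &'static str {
        // An explicit multi_thread setting wins; otherwise agent mode stays single-threaded.
        if multi_thread.unwrap_or(mode != "agent") {
            "multi-thread"
        } else {
            "current-thread"
        }
    }

    fn main() {
        assert_eq!(runtime_kind(None, "agent"), "current-thread");        // agent default
        assert_eq!(runtime_kind(None, "proxy"), "multi-thread");          // proxy default
        assert_eq!(runtime_kind(Some(false), "proxy"), "current-thread"); // explicit override
    }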
From bc1bbaf9b7c5690e6297a6c1a09679c194a9edd9 Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Tue, 3 Feb 2026 19:31:30 +0100
Subject: [PATCH 11/14] build fix

---
 src/main.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/main.rs b/src/main.rs
index 2ba686d..578e8f5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1085,6 +1085,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> {
         std::thread::spawn(move || {
             http_proxy::start::run_with_config(Some(crate::cli::Config {
                 mode: "proxy".to_string(),
+                multi_thread: None,
                 worker_threads: None,
                 redis: Default::default(),
                 network: network_config,

From 0dd60ad378d5e3abaa506e1e9c52c1462fa15bab Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Wed, 4 Feb 2026 12:25:06 +0100
Subject: [PATCH 12/14] feat(agent-status): add kernel_version + linux distro metadata

---
 src/main.rs | 103 +++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 86 insertions(+), 17 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 578e8f5..bf05947 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -985,12 +985,13 @@ async fn async_main(args: Args, config: Config) -> Result<()> {
         })
         .unwrap_or_default();
 
-    let mut metadata = HashMap::new();
-    metadata.insert("os".to_string(), std::env::consts::OS.to_string());
-    metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string());
-    metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
-    metadata.insert("mode".to_string(), config.mode.clone());
-    metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone());
+    let mut metadata = HashMap::new();
+    metadata.insert("os".to_string(), std::env::consts::OS.to_string());
+    metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string());
+    metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
+    metadata.insert("mode".to_string(), config.mode.clone());
+    metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone());
+    add_platform_metadata(&mut metadata);
 
     let identity = AgentStatusIdentity {
         agent_id,
@@ -1265,17 +1266,85 @@ async fn async_main(args: Args, config: Config) -> Result<()> {
     std::process::exit(0);
 }
 
-fn read_env_non_empty(name: &str) -> Option<String> {
-    std::env::var(name)
-        .ok()
-        .map(|value| value.trim().to_string())
-        .filter(|value| !value.is_empty())
-}
-
-fn build_agent_id(agent_name: &str, workspace_id: &str) -> String {
-    if read_env_non_empty("AGENT_ID").is_some() {
-        warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id.");
-    }
+fn read_env_non_empty(name: &str) -> Option<String> {
+    std::env::var(name)
+        .ok()
+        .map(|value| value.trim().to_string())
+        .filter(|value| !value.is_empty())
+}
+
+fn add_platform_metadata(metadata: &mut HashMap<String, String>) {
+    // Requested extra fields: kernel version, linux version, linux type.
+    // Keep these in `metadata` so Arxignis doesn't need schema changes.
+
+    #[cfg(unix)]
+    {
+        if let Ok(uts) = nix::sys::utsname::uname() {
+            // "kernel version" here is the kernel release (e.g. 6.10.14-linuxkit).
+            metadata.insert(
+                "kernel_version".to_string(),
+                uts.release().to_string_lossy().into_owned(),
+            );
+        }
+    }
+
+    #[cfg(target_os = "linux")]
+    {
+        // Always include these keys on Linux, even if we can't detect distro details.
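+        // Illustrative values only (hypothetical): linux_type="ubuntu",
+        // linux_version="22.04", linux_pretty_name="Ubuntu 22.04.3 LTS".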
+        let (linux_type, linux_version, pretty_name) = read_linux_os_release()
+            .unwrap_or_else(|| ("linux".to_string(), String::new(), String::new()));
+
+        metadata.insert("linux_type".to_string(), linux_type);
+        metadata.insert("linux_version".to_string(), linux_version);
+
+        // Extra convenience field for display/debug (optional).
+        if !pretty_name.is_empty() {
+            metadata.insert("linux_pretty_name".to_string(), pretty_name);
+        }
+    }
+}
+
+#[cfg(target_os = "linux")]
+fn read_linux_os_release() -> Option<(String, String, String)> {
+    // Common on most distributions; may be missing in very minimal containers.
+    let contents = std::fs::read_to_string("/etc/os-release").ok()?;
+
+    let mut id = String::new();
+    let mut version_id = String::new();
+    let mut pretty_name = String::new();
+
+    for line in contents.lines() {
+        let line = line.trim();
+        if line.is_empty() || line.starts_with('#') {
+            continue;
+        }
+
+        let Some((key, raw_value)) = line.split_once('=') else {
+            continue;
+        };
+
+        // Strip surrounding quotes if present.
+        let value = raw_value.trim().trim_matches('"').to_string();
+
+        match key {
+            "ID" => id = value,
+            "VERSION_ID" => version_id = value,
+            "PRETTY_NAME" => pretty_name = value,
+            _ => {}
+        }
+    }
+
+    if id.is_empty() && version_id.is_empty() && pretty_name.is_empty() {
+        None
+    } else {
+        Some((id, version_id, pretty_name))
+    }
+}
+
+fn build_agent_id(agent_name: &str, workspace_id: &str) -> String {
+    if read_env_non_empty("AGENT_ID").is_some() {
+        warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id.");
+    }
 
     if workspace_id.trim().is_empty() {
         warn!(

From f5a49e88c4bce849d8911801987d93fb3c1e7013 Mon Sep 17 00:00:00 2001
From: krichard1212 <136473183+krichard1212@users.noreply.github.com>
Date: Wed, 4 Feb 2026 12:45:35 +0100
Subject: [PATCH 13/14] fix(agent-status): read kernel_version from /proc to avoid nix utsname feature

---
 src/main.rs | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index bf05947..01d974a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1277,14 +1277,15 @@ fn add_platform_metadata(metadata: &mut HashMap<String, String>) {
     // Requested extra fields: kernel version, linux version, linux type.
     // Keep these in `metadata` so Arxignis doesn't need schema changes.
 
-    #[cfg(unix)]
+    #[cfg(target_os = "linux")]
     {
-        if let Ok(uts) = nix::sys::utsname::uname() {
-            // "kernel version" here is the kernel release (e.g. 6.10.14-linuxkit).
-            metadata.insert(
-                "kernel_version".to_string(),
-                uts.release().to_string_lossy().into_owned(),
-            );
+        // Avoid depending on nix::sys::utsname feature flags in minimal builds (e.g. CI agent-only).
+        // This is effectively `uname -r` and works in containers as long as /proc is mounted.
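+        // Illustrative content (hypothetical): /proc/sys/kernel/osrelease holds a single
+        // line such as "5.15.0-122-generic"; trimming it yields the `uname -r` value.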
+ if let Ok(kernel_release) = std::fs::read_to_string("/proc/sys/kernel/osrelease") { + let kernel_release = kernel_release.trim(); + if !kernel_release.is_empty() { + metadata.insert("kernel_version".to_string(), kernel_release.to_string()); + } } } From b3be1b888ec75829250b15159e7e00f035d4dffc Mon Sep 17 00:00:00 2001 From: krichard1212 <136473183+krichard1212@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:08:29 +0100 Subject: [PATCH 14/14] Added retry for apt-update --- pkg/deb/Dockerfile | 4 ++-- pkg/deb/docker/test.Dockerfile | 2 +- pkg/debug/build-legacy.Dockerfile | 4 ++-- pkg/debug/build.Dockerfile | 4 ++-- pkg/docker/Dockerfile | 6 +++--- pkg/docker/build.Dockerfile | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pkg/deb/Dockerfile b/pkg/deb/Dockerfile index 7fbc461..81060a5 100644 --- a/pkg/deb/Dockerfile +++ b/pkg/deb/Dockerfile @@ -6,7 +6,7 @@ FROM ${IMAGE}:${IMAGE_TAG} RUN sed -i '/updates/d' /etc/apt/sources.list && \ sed -i 's/httpredir/archive/' /etc/apt/sources.list -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ ca-certificates \ curl \ @@ -14,7 +14,7 @@ RUN apt-get update && \ lsb-release && \ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ + (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends --allow-downgrades \ libc6=2.27-3ubuntu1.5 \ libc6-dev \ diff --git a/pkg/deb/docker/test.Dockerfile b/pkg/deb/docker/test.Dockerfile index d0e93f0..92b7b3a 100644 --- a/pkg/deb/docker/test.Dockerfile +++ b/pkg/deb/docker/test.Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:24.04 -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y systemd systemd-sysv dbus curl python3 python3-pip sudo && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/pkg/debug/build-legacy.Dockerfile b/pkg/debug/build-legacy.Dockerfile index ce004f4..3f9b895 100644 --- a/pkg/debug/build-legacy.Dockerfile +++ b/pkg/debug/build-legacy.Dockerfile @@ -8,7 +8,7 @@ RUN sed -i '/updates/d' /etc/apt/sources.list && \ sed -i 's/httpredir/archive/' /etc/apt/sources.list && \ sed -i 's|https://|http://|g' /etc/apt/sources.list -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ apt-transport-https \ ca-certificates \ @@ -17,7 +17,7 @@ RUN apt-get update && \ lsb-release && \ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ + (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ libc6 \ libc6-dev \ diff --git a/pkg/debug/build.Dockerfile b/pkg/debug/build.Dockerfile index 23d8249..da9360e 100644 --- a/pkg/debug/build.Dockerfile +++ b/pkg/debug/build.Dockerfile @@ -7,7 +7,7 @@ FROM ${IMAGE}:${IMAGE_TAG} RUN sed -i '/updates/d' /etc/apt/sources.list && \ sed -i 's/httpredir/archive/' /etc/apt/sources.list -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ ca-certificates \ curl \ @@ -15,7 +15,7 @@ RUN apt-get update && \ lsb-release && \ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ echo 
"deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ + (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends --allow-downgrades \ libc6=2.27-3ubuntu1.5 \ libc6-dev \ diff --git a/pkg/docker/Dockerfile b/pkg/docker/Dockerfile index 8702c1b..2a1968f 100644 --- a/pkg/docker/Dockerfile +++ b/pkg/docker/Dockerfile @@ -2,14 +2,14 @@ ARG BUILD_IMAGE="ubuntu" ARG BUILD_IMAGE_TAG="18.04" FROM debian:13 AS files - RUN apt-get update && apt-get install -y ca-certificates nftables + RUN (apt-get update || (sleep 5 && apt-get update)) && apt-get install -y ca-certificates nftables FROM ${BUILD_IMAGE}:${BUILD_IMAGE_TAG} AS builder RUN sed -i '/updates/d' /etc/apt/sources.list && \ sed -i 's/httpredir/archive/' /etc/apt/sources.list -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ ca-certificates \ curl \ @@ -17,7 +17,7 @@ RUN apt-get update && \ lsb-release && \ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ + (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends --allow-downgrades \ libc6=2.27-3ubuntu1.5 \ libc6-dev \ diff --git a/pkg/docker/build.Dockerfile b/pkg/docker/build.Dockerfile index 7f6eb0e..53d9d86 100644 --- a/pkg/docker/build.Dockerfile +++ b/pkg/docker/build.Dockerfile @@ -6,7 +6,7 @@ FROM ${IMAGE}:${IMAGE_TAG} RUN sed -i '/updates/d' /etc/apt/sources.list && \ sed -i 's/httpredir/archive/' /etc/apt/sources.list -RUN apt-get update && \ +RUN (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends \ ca-certificates \ curl \ @@ -14,7 +14,7 @@ RUN apt-get update && \ lsb-release && \ curl -fsSL https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" >> /etc/apt/sources.list.d/llvm.list && \ - apt-get update && \ + (apt-get update || (sleep 5 && apt-get update)) && \ apt-get install -y --no-install-recommends --allow-downgrades \ libc6=2.27-3ubuntu1.5 \ libc6-dev \