diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 68ed9e9..250fc94 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -26,6 +26,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ curl \ ca-certificates \ sudo \ + clangd-12 \ && rm -rf /var/lib/apt/lists/* # Build bpftool from source diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..2754888 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,25 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**", + "${workspaceFolder}/src/bpf/include", + "${workspaceFolder}/src/bpf/lib", + "/usr/include", + "/usr/include/x86_64-linux-gnu", + "/usr/include/bpf", + "/usr/local/include/bpf" + ], + "defines": [], + "compilerPath": "/usr/bin/clang", + "cStandard": "c17", + "cppStandard": "c++14", + "intelliSenseMode": "linux-clang-x64", + "compilerArgs": [ + "-Wno-int-conversion" + ] + } + ], + "version": 4 +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 6c9ee27..8915b15 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,12 +1,10 @@ { - // Point rust-analyzer at the actual crate(s) in this workspace "rust-analyzer.linkedProjects": [ - "synapse/Cargo.toml" + "Cargo.toml" ], // Enable all features if your crate uses optional features - "rust-analyzer.cargo.allFeatures": true, // Run clippy on save for quick lint feedback (optional) - "rust-analyzer.checkOnSave.command": "clippy", + "rust-analyzer.checkOnSave": true, // Use the local toolchain when available "rust-analyzer.rustc.source": "discover" -} +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 672b6da..6be7a38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,21 @@ homepage = "https://gen0sec.com" repository = "https://github.com/gen0sec/synapse" version = "0.4.2" edition = "2024" -keywords = ["bpf", "firewall", "reverse-proxy", "tls", "proxy", "proxy-protocol", "content-scanning", "threat-intelligence", "captcha", "hcaptcha", "recaptcha", "cloudflare-turnstile", "runtime"] +keywords = [ + "bpf", + "firewall", + "reverse-proxy", + "tls", + "proxy", + "proxy-protocol", + "content-scanning", + "threat-intelligence", + "captcha", + "hcaptcha", + "recaptcha", + "cloudflare-turnstile", + "runtime", +] readme = "README.md" [target.'cfg(unix)'.build-dependencies] @@ -25,19 +39,19 @@ tokio = { version = "1", features = [ ] } anyhow = "1" hyper = { version = "1", features = ["http1", "server"] } -hyper-util = { version = "0.1", features = [ - "server", - "tokio", - "http1", -] } +hyper-util = { version = "0.1", features = ["server", "tokio", "http1"] } http-body-util = "0.1" plain = "0.2.3" serde = { version = "1", features = ["derive"] } serde_json = "1" serde_yaml = "0.9" clap = { version = "4.5.54", features = ["derive"] } -nix = { version = "0.31.1", features = ["net", "fs"] } -redis = { version = "1.0", features = ["tokio-native-tls-comp", "connection-manager", "r2d2"]} +nix = { version = "0.31.1", features = ["net", "fs", "resource"] } +redis = { version = "1.0", features = [ + "tokio-native-tls-comp", + "connection-manager", + "r2d2", +] } native-tls = "0.2" tokio-rustls = "0.26.4" rustls = { version = "0.23.36", default-features = false, features = [ @@ -65,15 +79,24 @@ gethostname = "1.1.0" local-ip-address = "0.6.9" flate2 = "1.1" log = "0.4.29" -env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } -log4rs = { version = "1.3", features = ["console_appender", "file_appender", "rolling_file_appender", "json_encoder", "gzip"] } +env_logger = { version = "0.11", default-features = false, features = [ + "auto-color", + "humantime", +] } +log4rs = { version = "1.3", features = [ + "console_appender", + "file_appender", + "rolling_file_appender", + "json_encoder", + "gzip", +] } syslog = "7.0" jsonwebtoken = { version = "10.1", features = ["rust_crypto"] } uuid = { version = "1.20", features = ["v4", "serde"] } url = "2.5" clamav-tcp = "0.2" multer = "3.0" -proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol", rev = "ac28b27d317088f0e9e89805ada3b9f5cfbf5673"} +proxy-protocol = { git = "https://github.com/gen0sec/proxy-protocol", rev = "ac28b27d317088f0e9e89805ada3b9f5cfbf5673" } rand = "0.9" regex = "1.0" daemonize = "0.5.0" @@ -85,13 +108,17 @@ daemonize = "0.5.0" # pingora-http = { path = "../pingora/pingora-http" } # pingora-memory-cache = { path = "../pingora/pingora-memory-cache" } -wirefilter-engine = { git = "https://github.com/gen0sec/wirefilter" , rev = "ab901470a24aad789cb9c03dd214d6c7d4cab589" } -pingora = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", features = ["lb", "openssl", "proxy"] } -pingora-core = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-proxy = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-limits = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-http = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} -pingora-memory-cache = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81"} +wirefilter-engine = { git = "https://github.com/gen0sec/wirefilter", rev = "ab901470a24aad789cb9c03dd214d6c7d4cab589" } +pingora = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81", features = [ + "lb", + "openssl", + "proxy", +] } +pingora-core = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81" } +pingora-proxy = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81" } +pingora-limits = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81" } +pingora-http = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81" } +pingora-memory-cache = { git = "https://github.com/gen0sec/pingora", rev = "c92146d621542303dd9b93a4cb5252e1eef46c81" } # JA4+ fingerprinting library # nstealth = { path = "../nstealth" } diff --git a/bin/act b/bin/act new file mode 100755 index 0000000..8dd1fbd Binary files /dev/null and b/bin/act differ diff --git a/build.rs b/build.rs index e69aed2..28f5e70 100644 --- a/build.rs +++ b/build.rs @@ -9,7 +9,7 @@ use std::path::{Path, PathBuf}; #[cfg(all(unix, feature = "bpf"))] use libbpf_cargo::SkeletonBuilder; -const SRC: &str = "src/security/firewall/bpf/filter.bpf.c"; +const SRC: &str = "src/security/firewall/bpf/xdp.bpf.c"; const HEADER_DIR: &str = "src/security/firewall/bpf"; fn main() { @@ -47,10 +47,14 @@ fn main() { assert!(Path::new(&vmlinux_include).exists(), "vmlinux.h not found"); let bpf_include = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()).join(HEADER_DIR); + let include = bpf_include.join("include"); + let lib = bpf_include.join("lib"); + // ✅ Construct full output path in OUT_DIR let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); - let skel_path = out_dir.join("filter.skel.rs"); + let skel_path = out_dir.join("xdp.skel.rs"); + // ✅ Pass the full path to build_and_generate SkeletonBuilder::new() .source(SRC) .clang_args([ @@ -58,14 +62,20 @@ fn main() { vmlinux_include.as_os_str(), OsStr::new("-I"), bpf_include.as_os_str(), - OsStr::new("-O3"), // Max optimizations to reduce program size - OsStr::new("-fno-unroll-loops"), // Reduce instruction count - OsStr::new("-Wall"), - OsStr::new("-Wextra"), - OsStr::new("-DBPF_NO_PRESERVE_ACCESS_INDEX"), // Older clang compat - OsStr::new("-Ubpf"), // Avoid macro collision + OsStr::new("-I"), + include.as_os_str(), + OsStr::new("-I"), + lib.as_os_str(), + OsStr::new("-O3"), // Maximum optimizations to reduce program size + OsStr::new("-fno-unroll-loops"), // Prevent loop unrolling to reduce instruction count + OsStr::new("-Wall"), // Enable all warnings + OsStr::new("-Wextra"), // Extra warnings + OsStr::new("-DBPF_NO_PRESERVE_ACCESS_INDEX"), // Disable preserve_access_index for older clang + OsStr::new("-Ubpf"), // Undefine bpf macro to avoid conflict with struct netns_bpf bpf ]) .build_and_generate(skel_path.to_str().expect("Invalid UTF-8 in path")) .expect("Failed to generate skeleton"); + + println!("✅ Wrote skeleton to: {:?}", skel_path); } } diff --git a/config/config.yaml b/config/config.yaml index 5518cdc..40a9e66 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -28,7 +28,6 @@ # Backward compatibility: AX_ prefix and AX_ARXIGNIS_ prefix are still supported # but will log deprecation warnings. - # Application operating mode # - agent: Only access rules and monitoring (no proxy, pingora disabled) # - proxy: Full reverse proxy functionality (default) @@ -36,53 +35,66 @@ mode: "agent" # Network Configuration network: - # The network interface to attach the XDP program to - iface: "eth0" + # The network interface to attach the XDP program to + iface: "eth0" # Additional network interfaces for XDP attach (overrides iface if set) ifaces: [] - # IP version support mode: "ipv4", "ipv6", or "both" (default: "both") - # Note: XDP requires IPv6 to be enabled at kernel level for attachment, - # even in IPv4-only mode. This is a kernel limitation. When set to "ipv4", - # the system will attempt to enable IPv6 on the interface just for XDP - # attachment (not system-wide), allowing IPv4-only operation elsewhere. - # Options: - # - "ipv4": Process IPv4 packets only (IPv6 still required for XDP attachment) - # - "ipv6": Process IPv6 packets only - # - "both": Process both IPv4 and IPv6 packets (default) - ip_version: "both" + # IP version support mode: "ipv4", "ipv6", or "both" (default: "both") + # Note: XDP requires IPv6 to be enabled at kernel level for attachment, + # even in IPv4-only mode. This is a kernel limitation. When set to "ipv4", + # the system will attempt to enable IPv6 on the interface just for XDP + # attachment (not system-wide), allowing IPv4-only operation elsewhere. + # Options: + # - "ipv4": Process IPv4 packets only (IPv6 still required for XDP attachment) + # - "ipv6": Process IPv6 packets only + # - "both": Process both IPv4 and IPv6 packets (default) + ip_version: "both" # Firewall Configuration firewall: - # Firewall backend mode: auto, xdp, nftables, iptables, none - # - auto: Automatically select best available (XDP > nftables > iptables > none) - # - xdp: Force XDP/BPF backend (highest performance, requires kernel support) - # - nftables: Force nftables backend (requires nft command and kernel support) - # - iptables: Force iptables backend (legacy, most compatible) - # - none: Disable kernel firewall, userland enforcement only - mode: "auto" - - # Disable XDP packet filtering (run without BPF/XDP) - disable_xdp: false + # Firewall backend mode: auto, xdp, nftables, iptables, none + # - auto: Automatically select best available (XDP > nftables > iptables > none) + # - xdp: Force XDP/BPF backend (highest performance, requires kernel support) + # - nftables: Force nftables backend (requires nft command and kernel support) + # - iptables: Force iptables backend (legacy, most compatible) + # - none: Disable kernel firewall, userland enforcement only + mode: "auto" + + # Disable XDP packet filtering (run without BPF/XDP) + disable_xdp: false + + # XDP Rate Limiter Configuration + ratelimiter: + # Enable XDP rate limiter (0 = disabled, 1 = enabled) + status: 0 + # Requests per second per IP + request_per_sec: 1000 + # Burst factor (capacity = request_per_sec * burst_factor) + burst_factor: 3.0 + # Map size for IPv4 addresses (max unique IPs). Requires restart to change. + ipv4_map_size: 50000 + # Map size for IPv6 addresses (max unique IPs). Requires restart to change. + ipv6_map_size: 50000 # Gen0Sec Platform Configuration # Note: 'arxignis' is also accepted for backward compatibility but deprecated platform: - # API key for Gen0Sec service - api_key: "" + # API key for Gen0Sec service + api_key: "" - # Base URL for Gen0Sec API - base_url: "https://api.gen0sec.com/v1" + # Base URL for Gen0Sec API + base_url: "https://api.gen0sec.com/v1" - # Enable sending access logs to platform server - log_sending_enabled: true + # Enable sending access logs to platform server + log_sending_enabled: true - # Include response body in access logs - include_response_body: true + # Include response body in access logs + include_response_body: true - # Maximum size for request/response bodies in access logs (bytes) - Don't override in Basic plan that's the maximum allowed by the plan. - max_body_size: 1048576 + # Maximum size for request/response bodies in access logs (bytes) - Don't override in Basic plan that's the maximum allowed by the plan. + max_body_size: 1048576 # Threat MMDB Configuration (used by platform threat intelligence) threat: @@ -102,8 +114,8 @@ platform: # Logging Configuration logging: - # Log level: error, warn, info, debug, trace - level: "info" + # Log level: error, warn, info, debug, trace + level: "info" # Enable file-based logging with separate files for errors and access logs # When enabled, logs will be written to separate files with automatic rotation and gzip compression @@ -198,23 +210,23 @@ logging: # Daemon Configuration daemon: - # Enable daemon mode (run as background process) - enabled: true + # Enable daemon mode (run as background process) + enabled: true - # PID file path - pid_file: "/var/run/synapse.pid" + # PID file path + pid_file: "/var/run/synapse.pid" # Working directory for daemon working_directory: "/var/lib/synapse" - # User to run daemon as (optional, e.g., "nobody") - user: root + # User to run daemon as (optional, e.g., "nobody") + user: root - # Group to run daemon as (optional, e.g., "daemon") - group: root + # Group to run daemon as (optional, e.g., "daemon") + group: root - # Change ownership of PID file to daemon user/group - chown_pid_file: true + # Change ownership of PID file to daemon user/group + chown_pid_file: true # Proxy Configuration (proxy mode features) # Note: 'pingora' is also accepted for backward compatibility but deprecated @@ -231,9 +243,9 @@ proxy: # TLS suite grade (high, medium, unsafe) tls_grade: "medium" - # Default fallback SSL certificate name (file stem without extension, e.g., "default" for default.crt) - # If not specified, the first valid certificate will be used as default - default_certificate: "default" + # Default fallback SSL certificate name (file stem without extension, e.g., "default" for default.crt) + # If not specified, the first valid certificate will be used as default + default_certificate: "default" # Redis Configuration (for ACME certificate storage) redis: diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..8fc8ab8 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,1222 @@ +use std::{env, path::PathBuf}; + +use anyhow::Result; +use clap::Parser; +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; + +use crate::waf::actions::captcha::CaptchaProvider; + +/// TLS operating mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] +#[serde(rename_all = "lowercase")] +pub enum TlsMode { + /// TLS is disabled + Disabled, +} + +/// Application operating mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)] +#[serde(rename_all = "lowercase")] +pub enum AppMode { + /// Agent mode: Only access rules and monitoring (no proxy) + Agent, + /// Proxy mode: Full reverse proxy functionality + Proxy, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + #[serde(default = "default_mode")] + pub mode: String, + + // Global server options (moved from server section) + #[serde(default)] + pub redis: RedisConfig, + #[serde(default)] + pub network: NetworkConfig, + #[serde(default)] + pub arxignis: Gen0SecConfig, + #[serde(default)] + pub geoip: GeoipConfig, + #[serde(default)] + pub content_scanning: ContentScanningCliConfig, + #[serde(default)] + pub logging: LoggingConfig, + #[serde(default)] + pub bpf_stats: BpfStatsConfig, + #[serde(default)] + pub tcp_fingerprint: TcpFingerprintConfig, + #[serde(default)] + pub daemon: DaemonConfig, + #[serde(default)] + pub pingora: PingoraConfig, + #[serde(default)] + pub acme: AcmeConfig, +} + +fn default_mode() -> String { + "proxy".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProxyProtocolConfig { + #[serde(default = "default_proxy_protocol_enabled")] + pub enabled: bool, + #[serde(default = "default_proxy_protocol_timeout")] + pub timeout_ms: u64, +} + +fn default_proxy_protocol_enabled() -> bool { + false +} +fn default_proxy_protocol_timeout() -> u64 { + 1000 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HealthCheckConfig { + #[serde(default = "default_health_check_enabled")] + pub enabled: bool, + #[serde(default = "default_health_check_endpoint")] + pub endpoint: String, + #[serde(default = "default_health_check_port")] + pub port: String, + #[serde(default = "default_health_check_methods")] + pub methods: Vec, + #[serde(default = "default_health_check_allowed_cidrs")] + pub allowed_cidrs: Vec, +} + +fn default_health_check_enabled() -> bool { + true +} +fn default_health_check_endpoint() -> String { + "/health".to_string() +} +fn default_health_check_port() -> String { + "0.0.0.0:8080".to_string() +} +fn default_health_check_methods() -> Vec { + vec!["GET".to_string(), "HEAD".to_string()] +} +fn default_health_check_allowed_cidrs() -> Vec { + vec![] +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RedisConfig { + #[serde(default)] + pub url: String, + #[serde(default)] + pub prefix: String, + /// Redis SSL/TLS configuration + #[serde(default)] + pub ssl: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RedisSslConfig { + /// Path to CA certificate file (PEM format) + pub ca_cert_path: Option, + /// Path to client certificate file (PEM format, optional) + pub client_cert_path: Option, + /// Path to client private key file (PEM format, optional) + pub client_key_path: Option, + /// Skip certificate verification (for testing with self-signed certs) + #[serde(default)] + pub insecure: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct NetworkConfig { + #[serde(default)] + pub iface: String, + #[serde(default)] + pub ifaces: Vec, + #[serde(default)] + pub disable_xdp: bool, + /// IP version support mode: "ipv4", "ipv6", or "both" (default: "both") + /// Note: XDP requires IPv6 to be enabled at kernel level for attachment, + /// even in IPv4-only mode. Set to "ipv4" to skip XDP if IPv6 cannot be enabled. + #[serde(default = "default_ip_version")] + pub ip_version: String, +} + +fn default_ip_version() -> String { + "both".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Gen0SecConfig { + #[serde(default)] + pub api_key: String, + #[serde(default = "default_base_url")] + pub base_url: String, + /// Threat MMDB database configuration + #[serde(default)] + pub threat: GeoipDatabaseConfig, + #[serde(default = "default_log_sending_enabled")] + pub log_sending_enabled: bool, + #[serde(default = "default_include_response_body")] + pub include_response_body: bool, + #[serde(default = "default_max_body_size")] + pub max_body_size: usize, + #[serde(default)] + pub captcha: CaptchaConfig, +} + +fn default_base_url() -> String { + "https://api.gen0sec.com/v1".to_string() +} + +fn default_log_sending_enabled() -> bool { + true +} + +fn default_include_response_body() -> bool { + true +} + +fn default_max_body_size() -> usize { + 1024 * 1024 // 1MB +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct GeoipDatabaseConfig { + /// URL to download GeoIP MMDB file + #[serde(default)] + pub url: String, + /// Optional local path to store/read the GeoIP MMDB file + #[serde(default)] + pub path: Option, + /// Optional custom headers to add to download requests + #[serde(default)] + pub headers: Option>, + /// Refresh interval in seconds (optional, overrides parent refresh_secs if set) + #[serde(default)] + pub refresh_secs: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct GeoipConfig { + /// Country database configuration + #[serde(default)] + pub country: GeoipDatabaseConfig, + /// ASN database configuration + #[serde(default)] + pub asn: GeoipDatabaseConfig, + /// City database configuration + #[serde(default)] + pub city: GeoipDatabaseConfig, + /// How often to refresh the GeoIP MMDB from the remote URL (seconds). Default: 28800 (8 hours) + #[serde(default = "default_geoip_refresh_secs")] + pub refresh_secs: u64, +} + +fn default_geoip_refresh_secs() -> u64 { + 28800 // 8 hours +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ContentScanningCliConfig { + #[serde(default = "default_scanning_enabled")] + pub enabled: bool, + #[serde(default = "default_clamav_server")] + pub clamav_server: String, + #[serde(default = "default_max_file_size")] + pub max_file_size: usize, + #[serde(default)] + pub scan_content_types: Vec, + #[serde(default)] + pub skip_extensions: Vec, + #[serde(default = "default_scan_expression")] + pub scan_expression: String, +} + +fn default_scanning_enabled() -> bool { + false +} +fn default_clamav_server() -> String { + "localhost:3310".to_string() +} +fn default_max_file_size() -> usize { + 10 * 1024 * 1024 +} +fn default_scan_expression() -> String { + "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LoggingConfig { + #[serde(default)] + pub level: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CaptchaConfig { + #[serde(default)] + pub site_key: Option, + #[serde(default)] + pub secret_key: Option, + #[serde(default)] + pub jwt_secret: Option, + #[serde(default)] + pub provider: String, + #[serde(default)] + pub token_ttl: u64, + #[serde(default)] + pub cache_ttl: u64, +} + +impl Config { + pub fn load_from_file(path: &PathBuf) -> Result { + let content = std::fs::read_to_string(path)?; + let config: Config = serde_yaml::from_str(&content)?; + Ok(config) + } + + pub fn default() -> Self { + Self { + mode: "proxy".to_string(), + redis: RedisConfig { + url: "redis://127.0.0.1/0".to_string(), + prefix: "ax:synapse".to_string(), + ssl: None, + }, + network: NetworkConfig { + iface: "eth0".to_string(), + ifaces: vec![], + disable_xdp: false, + ip_version: "both".to_string(), + }, + arxignis: Gen0SecConfig { + api_key: "".to_string(), + base_url: "https://api.gen0sec.com/v1".to_string(), + threat: GeoipDatabaseConfig::default(), + log_sending_enabled: true, + include_response_body: true, + max_body_size: 1024 * 1024, // 1MB + captcha: CaptchaConfig { + site_key: None, + secret_key: None, + jwt_secret: None, + provider: "hcaptcha".to_string(), + token_ttl: 7200, + cache_ttl: 300, + }, + }, + geoip: GeoipConfig { + country: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + asn: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + city: GeoipDatabaseConfig { + url: "".to_string(), + path: None, + headers: None, + refresh_secs: None, + }, + refresh_secs: 28800, + }, + content_scanning: ContentScanningCliConfig { + enabled: false, + clamav_server: "localhost:3310".to_string(), + max_file_size: 10 * 1024 * 1024, + scan_content_types: vec![ + "text/html".to_string(), + "application/x-www-form-urlencoded".to_string(), + "multipart/form-data".to_string(), + "application/json".to_string(), + "text/plain".to_string(), + ], + skip_extensions: vec![], + scan_expression: default_scan_expression(), + }, + logging: LoggingConfig { + level: "info".to_string(), + }, + bpf_stats: BpfStatsConfig::default(), + tcp_fingerprint: TcpFingerprintConfig::default(), + daemon: DaemonConfig::default(), + pingora: PingoraConfig::default(), + acme: AcmeConfig::default(), + } + } + + pub fn merge_with_args(&mut self, args: &Args) { + // Override config values with command line arguments if provided + + if !args.ifaces.is_empty() { + self.network.ifaces = args.ifaces.clone(); + } + if let Some(api_key) = &args.arxignis_api_key { + self.arxignis.api_key = api_key.clone(); + } + if !args.arxignis_base_url.is_empty() + && args.arxignis_base_url != "https://api.gen0sec.com/v1" + { + self.arxignis.base_url = args.arxignis_base_url.clone(); + } + if let Some(log_sending_enabled) = args.arxignis_log_sending_enabled { + self.arxignis.log_sending_enabled = log_sending_enabled; + } + self.arxignis.include_response_body = args.arxignis_include_response_body; + self.arxignis.max_body_size = args.arxignis_max_body_size; + if args.captcha_site_key.is_some() { + self.arxignis.captcha.site_key = args.captcha_site_key.clone(); + } + if args.captcha_secret_key.is_some() { + self.arxignis.captcha.secret_key = args.captcha_secret_key.clone(); + } + if args.captcha_jwt_secret.is_some() { + self.arxignis.captcha.jwt_secret = args.captcha_jwt_secret.clone(); + } + if let Some(provider) = &args.captcha_provider { + self.arxignis.captcha.provider = format!("{:?}", provider).to_lowercase(); + } + + // Proxy protocol configuration overrides + // if args.proxy_protocol_enabled { + // self.proxy_protocol.enabled = true; + // } + // if args.proxy_protocol_timeout != 1000 { + // self.proxy_protocol.timeout_ms = args.proxy_protocol_timeout; + // } + + // Daemon configuration overrides + if args.daemon { + self.daemon.enabled = true; + } + if args.daemon_pid_file != "/var/run/synapse.pid" { + self.daemon.pid_file = args.daemon_pid_file.clone(); + } + if args.daemon_working_dir != "/" { + self.daemon.working_directory = args.daemon_working_dir.clone(); + } + if args.daemon_stdout != "/var/log/synapse.out" { + self.daemon.stdout = args.daemon_stdout.clone(); + } + if args.daemon_stderr != "/var/log/synapse.err" { + self.daemon.stderr = args.daemon_stderr.clone(); + } + if args.daemon_user.is_some() { + self.daemon.user = args.daemon_user.clone(); + } + if args.daemon_group.is_some() { + self.daemon.group = args.daemon_group.clone(); + } + + // Redis configuration overrides + if !args.redis_url.is_empty() && args.redis_url != "redis://127.0.0.1/0" { + self.redis.url = args.redis_url.clone(); + } + if !args.redis_prefix.is_empty() && args.redis_prefix != "ax:synapse" { + self.redis.prefix = args.redis_prefix.clone(); + } + } + + pub fn validate_required_fields(&mut self, args: &Args) -> Result<()> { + // Check if arxignis API key is provided - only warn if not provided + // (to support old config format that doesn't have this field) + if args.arxignis_api_key.is_none() && self.arxignis.api_key.is_empty() { + log::warn!("Gen0Sec API key not provided. Some features may not work."); + } + + Ok(()) + } + + pub fn load_from_args(args: &Args) -> Result { + let mut config = if let Some(config_path) = &args.config { + Self::load_from_file(config_path)? + } else { + Self::default() + }; + + config.merge_with_args(args); + config.apply_env_overrides(); + config.validate_required_fields(args)?; + Ok(config) + } + + pub fn apply_env_overrides(&mut self) { + // Mode override + if let Ok(val) = env::var("AX_MODE") { + self.mode = val; + } + + // Redis configuration overrides + if let Ok(val) = env::var("AX_REDIS_URL") { + self.redis.url = val; + } + if let Ok(val) = env::var("AX_REDIS_PREFIX") { + self.redis.prefix = val; + } + + // Redis SSL configuration overrides + // Read all SSL environment variables once + let ca_cert_path = env::var("AX_REDIS_SSL_CA_CERT_PATH").ok(); + let client_cert_path = env::var("AX_REDIS_SSL_CLIENT_CERT_PATH").ok(); + let client_key_path = env::var("AX_REDIS_SSL_CLIENT_KEY_PATH").ok(); + let insecure_val = env::var("AX_REDIS_SSL_INSECURE").ok(); + + // If any SSL env var is set, ensure SSL config exists + if ca_cert_path.is_some() + || client_cert_path.is_some() + || client_key_path.is_some() + || insecure_val.is_some() + { + // Create SSL config if it doesn't exist + if self.redis.ssl.is_none() { + // Parse insecure value if provided, default to false + let insecure_default = insecure_val + .as_ref() + .and_then(|v| v.parse::().ok()) + .unwrap_or(false); + + self.redis.ssl = Some(RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: insecure_default, + }); + } + + // Update the SSL config with values from environment variables + let ssl = self + .redis + .ssl + .as_mut() + .expect("SSL config should exist here"); + if let Some(val) = ca_cert_path { + ssl.ca_cert_path = Some(val); + } + if let Some(val) = client_cert_path { + ssl.client_cert_path = Some(val); + } + if let Some(val) = client_key_path { + ssl.client_key_path = Some(val); + } + if let Some(val) = insecure_val { + if let Ok(insecure) = val.parse::() { + ssl.insecure = insecure; + } + } + } + + // Network configuration overrides + if let Ok(val) = env::var("AX_NETWORK_IFACE") { + self.network.iface = val; + } + if let Ok(val) = env::var("AX_NETWORK_IFACES") { + self.network.ifaces = val.split(',').map(|s| s.trim().to_string()).collect(); + } + if let Ok(val) = env::var("AX_NETWORK_DISABLE_XDP") { + self.network.disable_xdp = val.parse().unwrap_or(false); + } + if let Ok(val) = env::var("AX_NETWORK_IP_VERSION") { + // Validate ip_version value + match val.as_str() { + "ipv4" | "ipv6" | "both" => { + self.network.ip_version = val; + } + _ => { + log::warn!( + "Invalid AX_NETWORK_IP_VERSION value '{}', using default 'both'. Valid values: ipv4, ipv6, both", + val + ); + } + } + } + + // Gen0Sec configuration overrides + if let Ok(val) = env::var("AX_ARXIGNIS_API_KEY") { + self.arxignis.api_key = val; + } + if let Ok(val) = env::var("AX_ARXIGNIS_BASE_URL") { + self.arxignis.base_url = val; + } + if let Ok(val) = env::var("AX_ARXIGNIS_LOG_SENDING_ENABLED") { + if let Ok(parsed) = val.parse::() { + self.arxignis.log_sending_enabled = parsed; + } + } + if let Ok(val) = env::var("AX_ARXIGNIS_INCLUDE_RESPONSE_BODY") { + self.arxignis.include_response_body = val.parse().unwrap_or(true); + } + if let Ok(val) = env::var("AX_ARXIGNIS_MAX_BODY_SIZE") { + self.arxignis.max_body_size = val.parse().unwrap_or(1024 * 1024); + } + + // Logging configuration overrides + if let Ok(val) = env::var("AX_LOGGING_LEVEL") { + self.logging.level = val; + } + + // Content scanning overrides + if let Ok(val) = env::var("AX_CONTENT_SCANNING_ENABLED") { + self.content_scanning.enabled = val.parse().unwrap_or(false); + } + if let Ok(val) = env::var("AX_CLAMAV_SERVER") { + self.content_scanning.clamav_server = val; + } + if let Ok(val) = env::var("AX_CONTENT_MAX_FILE_SIZE") { + self.content_scanning.max_file_size = val.parse().unwrap_or(10 * 1024 * 1024); + } + if let Ok(val) = env::var("AX_CONTENT_SCAN_CONTENT_TYPES") { + self.content_scanning.scan_content_types = + val.split(',').map(|s| s.trim().to_string()).collect(); + } + if let Ok(val) = env::var("AX_CONTENT_SKIP_EXTENSIONS") { + self.content_scanning.skip_extensions = + val.split(',').map(|s| s.trim().to_string()).collect(); + } + if let Ok(val) = env::var("AX_CONTENT_SCAN_EXPRESSION") { + self.content_scanning.scan_expression = val; + } + + // Captcha configuration overrides + if let Ok(val) = env::var("AX_CAPTCHA_SITE_KEY") { + self.arxignis.captcha.site_key = Some(val); + } + if let Ok(val) = env::var("AX_CAPTCHA_SECRET_KEY") { + self.arxignis.captcha.secret_key = Some(val); + } + if let Ok(val) = env::var("AX_CAPTCHA_JWT_SECRET") { + self.arxignis.captcha.jwt_secret = Some(val); + } + if let Ok(val) = env::var("AX_CAPTCHA_PROVIDER") { + self.arxignis.captcha.provider = val; + } + if let Ok(val) = env::var("AX_CAPTCHA_TOKEN_TTL") { + self.arxignis.captcha.token_ttl = val.parse().unwrap_or(7200); + } + if let Ok(val) = env::var("AX_CAPTCHA_CACHE_TTL") { + self.arxignis.captcha.cache_ttl = val.parse().unwrap_or(300); + } + + // Proxy protocol configuration overrides + // if let Ok(val) = env::var("AX_PROXY_PROTOCOL_ENABLED") { + // self.proxy_protocol.enabled = val.parse().unwrap_or(false); + // } + // if let Ok(val) = env::var("AX_PROXY_PROTOCOL_TIMEOUT") { + // self.proxy_protocol.timeout_ms = val.parse().unwrap_or(1000); + // } + + // Daemon configuration overrides + if let Ok(val) = env::var("AX_DAEMON_ENABLED") { + self.daemon.enabled = val.parse().unwrap_or(false); + } + if let Ok(val) = env::var("AX_DAEMON_PID_FILE") { + self.daemon.pid_file = val; + } + if let Ok(val) = env::var("AX_DAEMON_WORKING_DIRECTORY") { + self.daemon.working_directory = val; + } + if let Ok(val) = env::var("AX_DAEMON_STDOUT") { + self.daemon.stdout = val; + } + if let Ok(val) = env::var("AX_DAEMON_STDERR") { + self.daemon.stderr = val; + } + if let Ok(val) = env::var("AX_DAEMON_USER") { + self.daemon.user = Some(val); + } + if let Ok(val) = env::var("AX_DAEMON_GROUP") { + self.daemon.group = Some(val); + } + if let Ok(val) = env::var("AX_DAEMON_CHOWN_PID_FILE") { + self.daemon.chown_pid_file = val.parse().unwrap_or(true); + } + } +} + +#[derive(Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// Path to configuration file (YAML format) + #[arg(long, short = 'c')] + pub config: Option, + + /// Path to security rules configuration file (access rules and WAF rules) + /// Used when no Gen0Sec API key is provided as a fallback + #[arg(long, default_value = "security_rules.yaml")] + pub security_rules_config: PathBuf, + + /// Clear a specific certificate from local filesystem and Redis + #[arg(long)] + pub clear_certificate: Option, + + /// Redis connection URL for ACME cache storage. + #[arg(long, default_value = "redis://127.0.0.1/0")] + pub redis_url: String, + + /// Namespace prefix for Redis ACME cache entries. + #[arg(long, default_value = "ax:synapse")] + pub redis_prefix: String, + + /// The network interface to attach the XDP program to. + #[arg(short, long, default_value = "eth0")] + pub iface: String, + + /// Additional network interfaces for XDP attach (comma-separated). If set, overrides --iface. + #[arg(long, value_delimiter = ',', num_args = 0..)] + pub ifaces: Vec, + + #[arg(long)] + pub arxignis_api_key: Option, + + /// Base URL for Gen0Sec API. + #[arg(long, default_value = "https://api.gen0sec.com/v1")] + pub arxignis_base_url: String, + + /// Enable sending access logs to arxignis server + #[arg(long)] + pub arxignis_log_sending_enabled: Option, + + /// Include response body in access logs + #[arg(long, default_value_t = true)] + pub arxignis_include_response_body: bool, + + /// Maximum size for request/response bodies in access logs (bytes) + #[arg(long, default_value = "1048576")] + pub arxignis_max_body_size: usize, + + /// Log level (error, warn, info, debug, trace) + #[arg(long, value_enum, default_value_t = LogLevel::Info)] + pub log_level: LogLevel, + + /// Disable XDP packet filtering (run without BPF/XDP) + #[arg(long, default_value_t = false)] + pub disable_xdp: bool, + + /// Captcha site key for security verification + #[arg(long)] + pub captcha_site_key: Option, + + /// Captcha secret key for security verification + #[arg(long)] + pub captcha_secret_key: Option, + + /// JWT secret key for captcha token signing + #[arg(long)] + pub captcha_jwt_secret: Option, + + /// Captcha provider (hcaptcha, recaptcha, turnstile) + #[arg(long, value_enum)] + pub captcha_provider: Option, + + /// Captcha token TTL in seconds + #[arg(long, default_value = "7200")] + pub captcha_token_ttl: u64, + + /// Captcha validation cache TTL in seconds + #[arg(long, default_value = "300")] + pub captcha_cache_ttl: u64, + + /// Enable PROXY protocol support for TCP connections + #[arg(long, default_value_t = false)] + pub proxy_protocol_enabled: bool, + + /// PROXY protocol timeout in milliseconds + #[arg(long, default_value = "1000")] + pub proxy_protocol_timeout: u64, + + /// Run as daemon in background + #[arg(long, short = 'd', default_value_t = false)] + pub daemon: bool, + + /// PID file path for daemon mode + #[arg(long, default_value = "/var/run/synapse.pid")] + pub daemon_pid_file: String, + + /// Working directory for daemon mode + #[arg(long, default_value = "/")] + pub daemon_working_dir: String, + + /// Stdout log file for daemon mode + #[arg(long, default_value = "/var/log/synapse.out")] + pub daemon_stdout: String, + + /// Stderr log file for daemon mode + #[arg(long, default_value = "/var/log/synapse.err")] + pub daemon_stderr: String, + + /// User to run daemon as + #[arg(long)] + pub daemon_user: Option, + + /// Group to run daemon as + #[arg(long)] + pub daemon_group: Option, +} + +#[derive(Copy, Clone, Debug, ValueEnum)] +pub enum LogLevel { + Error, + Warn, + Info, + Debug, + Trace, +} + +impl LogLevel { + pub fn to_level_filter(self) -> log::LevelFilter { + match self { + LogLevel::Error => log::LevelFilter::Error, + LogLevel::Warn => log::LevelFilter::Warn, + LogLevel::Info => log::LevelFilter::Info, + LogLevel::Debug => log::LevelFilter::Debug, + LogLevel::Trace => log::LevelFilter::Trace, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BpfStatsConfig { + #[serde(default = "default_bpf_stats_enabled")] + pub enabled: bool, + #[serde(default = "default_bpf_stats_log_interval")] + pub log_interval_secs: u64, + #[serde(default = "default_bpf_stats_enable_dropped_ip_events")] + pub enable_dropped_ip_events: bool, + #[serde(default = "default_bpf_stats_dropped_ip_events_interval")] + pub dropped_ip_events_interval_secs: u64, +} + +fn default_bpf_stats_enabled() -> bool { + true +} +fn default_bpf_stats_log_interval() -> u64 { + 60 +} +fn default_bpf_stats_enable_dropped_ip_events() -> bool { + true +} +fn default_bpf_stats_dropped_ip_events_interval() -> u64 { + 30 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct TcpFingerprintConfig { + #[serde(default = "default_tcp_fingerprint_enabled")] + pub enabled: bool, + #[serde(default = "default_tcp_fingerprint_log_interval")] + pub log_interval_secs: u64, + #[serde(default = "default_tcp_fingerprint_enable_fingerprint_events")] + pub enable_fingerprint_events: bool, + #[serde(default = "default_tcp_fingerprint_events_interval")] + pub fingerprint_events_interval_secs: u64, + #[serde(default = "default_tcp_fingerprint_min_packet_count")] + pub min_packet_count: u32, + #[serde(default = "default_tcp_fingerprint_min_connection_duration")] + pub min_connection_duration_secs: u64, +} + +fn default_tcp_fingerprint_enabled() -> bool { + true +} +fn default_tcp_fingerprint_log_interval() -> u64 { + 60 +} +fn default_tcp_fingerprint_enable_fingerprint_events() -> bool { + true +} +fn default_tcp_fingerprint_events_interval() -> u64 { + 30 +} +fn default_tcp_fingerprint_min_packet_count() -> u32 { + 3 +} +fn default_tcp_fingerprint_min_connection_duration() -> u64 { + 1 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct DaemonConfig { + #[serde(default = "default_daemon_enabled")] + pub enabled: bool, + #[serde(default = "default_daemon_pid_file")] + pub pid_file: String, + #[serde(default = "default_daemon_working_directory")] + pub working_directory: String, + #[serde(default = "default_daemon_stdout")] + pub stdout: String, + #[serde(default = "default_daemon_stderr")] + pub stderr: String, + pub user: Option, + pub group: Option, + #[serde(default = "default_daemon_chown_pid_file")] + pub chown_pid_file: bool, +} + +fn default_daemon_enabled() -> bool { + false +} +fn default_daemon_pid_file() -> String { + "/var/run/synapse.pid".to_string() +} +fn default_daemon_working_directory() -> String { + "/".to_string() +} +fn default_daemon_stdout() -> String { + "/var/log/synapse.out".to_string() +} +fn default_daemon_stderr() -> String { + "/var/log/synapse.err".to_string() +} +fn default_daemon_chown_pid_file() -> bool { + true +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PingoraConfig { + #[serde(default)] + pub proxy_address_http: String, + #[serde(default)] + pub proxy_address_tls: Option, + #[serde(default)] + pub proxy_certificates: Option, + #[serde(default = "default_pingora_tls_grade")] + pub proxy_tls_grade: String, + #[serde(default)] + pub default_certificate: Option, + #[serde(default)] + pub upstreams_conf: String, + #[serde(default)] + pub config_address: String, + #[serde(default = "default_pingora_config_api_enabled")] + pub config_api_enabled: bool, + #[serde(default)] + pub master_key: String, + #[serde(default = "default_pingora_log_level")] + pub log_level: String, + #[serde(default = "default_pingora_healthcheck_method")] + pub healthcheck_method: String, + #[serde(default = "default_pingora_healthcheck_interval")] + pub healthcheck_interval: u16, + #[serde(default)] + pub proxy_protocol: ProxyProtocolConfig, +} + +fn default_pingora_tls_grade() -> String { + "medium".to_string() +} +fn default_pingora_config_api_enabled() -> bool { + true +} +fn default_pingora_log_level() -> String { + "debug".to_string() +} +fn default_pingora_healthcheck_method() -> String { + "HEAD".to_string() +} +fn default_pingora_healthcheck_interval() -> u16 { + 2 +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcmeConfig { + /// Enable embedded ACME server + #[serde(default = "default_acme_enabled")] + pub enabled: bool, + /// Port for ACME server (e.g., 9180) + #[serde(default = "default_acme_port")] + pub port: u16, + /// Email for ACME account + pub email: Option, + /// Storage path for certificates + #[serde(default = "default_acme_storage_path")] + pub storage_path: String, + /// Storage type: "file" or "redis" (defaults to "file", or "redis" if redis_url is set) + pub storage_type: Option, + /// Use development/staging ACME server + #[serde(default)] + pub development: bool, + /// Redis URL for storage (optional, uses global redis.url if not set) + pub redis_url: Option, +} + +impl Default for AcmeConfig { + fn default() -> Self { + Self { + enabled: default_acme_enabled(), + port: default_acme_port(), + email: None, + storage_path: default_acme_storage_path(), + storage_type: None, + development: false, + redis_url: None, + } + } +} + +fn default_acme_enabled() -> bool { + false +} +fn default_acme_port() -> u16 { + 9180 +} +fn default_acme_storage_path() -> String { + "/tmp/synapse-acme".to_string() +} + +impl PingoraConfig { + /// Convert PingoraConfig to AppConfig for compatibility with old proxy system + pub fn to_app_config(&self) -> crate::utils::structs::AppConfig { + let mut app_config = crate::utils::structs::AppConfig::default(); + app_config.proxy_address_http = self.proxy_address_http.clone(); + app_config.proxy_address_tls = self.proxy_address_tls.clone(); + app_config.proxy_certificates = self.proxy_certificates.clone(); + app_config.proxy_tls_grade = Some(self.proxy_tls_grade.clone()); + app_config.default_certificate = self.default_certificate.clone(); + app_config.upstreams_conf = self.upstreams_conf.clone(); + app_config.config_address = self.config_address.clone(); + app_config.config_api_enabled = self.config_api_enabled; + app_config.master_key = self.master_key.clone(); + app_config.healthcheck_method = self.healthcheck_method.clone(); + app_config.healthcheck_interval = self.healthcheck_interval; + app_config.proxy_protocol_enabled = self.proxy_protocol.enabled; + + // Parse config_address to local_server + if let Some((ip, port_str)) = self.config_address.split_once(':') { + if let Ok(port) = port_str.parse::() { + app_config.local_server = Some((ip.to_string(), port)); + } + } + + // Parse proxy_address_tls to proxy_port_tls + if let Some(ref tls_addr) = self.proxy_address_tls { + if let Some((_, port_str)) = tls_addr.split_once(':') { + if let Ok(port) = port_str.parse::() { + app_config.proxy_port_tls = Some(port); + } + } + } + + app_config + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::env; + + #[test] + fn test_redis_ssl_config_deserialize() { + let yaml = r#" +ca_cert_path: "/path/to/ca.crt" +client_cert_path: "/path/to/client.crt" +client_key_path: "/path/to/client.key" +insecure: true +"#; + let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.ca_cert_path, Some("/path/to/ca.crt".to_string())); + assert_eq!( + config.client_cert_path, + Some("/path/to/client.crt".to_string()) + ); + assert_eq!( + config.client_key_path, + Some("/path/to/client.key".to_string()) + ); + assert!(config.insecure); + } + + #[test] + fn test_redis_ssl_config_deserialize_minimal() { + let yaml = r#" +insecure: false +"#; + let config: RedisSslConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.ca_cert_path, None); + assert_eq!(config.client_cert_path, None); + assert_eq!(config.client_key_path, None); + assert!(!config.insecure); + } + + #[test] + fn test_redis_config_with_ssl() { + let yaml = r#" +url: "rediss://localhost:6379" +prefix: "test:prefix" +ssl: + ca_cert_path: "/path/to/ca.crt" + insecure: false +"#; + let config: RedisConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.url, "rediss://localhost:6379"); + assert_eq!(config.prefix, "test:prefix"); + assert!(config.ssl.is_some()); + let ssl = config.ssl.unwrap(); + assert_eq!(ssl.ca_cert_path, Some("/path/to/ca.crt".to_string())); + assert!(!ssl.insecure); + } + + #[test] + fn test_redis_config_without_ssl() { + let yaml = r#" +url: "redis://localhost:6379" +prefix: "test:prefix" +"#; + let config: RedisConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.url, "redis://localhost:6379"); + assert_eq!(config.prefix, "test:prefix"); + assert!(config.ssl.is_none()); + } + + #[test] + fn test_redis_config_default() { + let config = RedisConfig::default(); + assert_eq!(config.url, ""); + assert_eq!(config.prefix, ""); + assert!(config.ssl.is_none()); + } + + // Helper function to clean up SSL environment variables for test isolation + fn cleanup_redis_ssl_env_vars() { + unsafe { + env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_ca_cert() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + assert_eq!( + config.redis.ssl.as_ref().unwrap().ca_cert_path, + Some("/test/ca.crt".to_string()) + ); + + unsafe { + env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_client_cert() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt"); + env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + let ssl = config.redis.ssl.as_ref().unwrap(); + assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); + assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); + + unsafe { + env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_insecure() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_INSECURE", "true"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + assert!(config.redis.ssl.as_ref().unwrap().insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_insecure_false() { + cleanup_redis_ssl_env_vars(); + + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_INSECURE", "false"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + assert!(!config.redis.ssl.as_ref().unwrap().insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + #[serial] + fn test_apply_env_overrides_redis_ssl_combined() { + cleanup_redis_ssl_env_vars(); + let mut config = Config::default(); + unsafe { + env::set_var("AX_REDIS_SSL_CA_CERT_PATH", "/test/ca.crt"); + env::set_var("AX_REDIS_SSL_CLIENT_CERT_PATH", "/test/client.crt"); + env::set_var("AX_REDIS_SSL_CLIENT_KEY_PATH", "/test/client.key"); + env::set_var("AX_REDIS_SSL_INSECURE", "true"); + } + + config.apply_env_overrides(); + + assert!(config.redis.ssl.is_some()); + let ssl = config.redis.ssl.as_ref().unwrap(); + assert_eq!(ssl.ca_cert_path, Some("/test/ca.crt".to_string())); + assert_eq!(ssl.client_cert_path, Some("/test/client.crt".to_string())); + assert_eq!(ssl.client_key_path, Some("/test/client.key".to_string())); + assert!(ssl.insecure); + + unsafe { + env::remove_var("AX_REDIS_SSL_CA_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_CERT_PATH"); + env::remove_var("AX_REDIS_SSL_CLIENT_KEY_PATH"); + env::remove_var("AX_REDIS_SSL_INSECURE"); + } + } + + #[test] + fn test_redis_ssl_config_default_insecure() { + let config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + assert!(!config.insecure); + } +} diff --git a/src/core/app_state.rs b/src/core/app_state.rs index cef94e6..7257021 100644 --- a/src/core/app_state.rs +++ b/src/core/app_state.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex}; #[derive(Clone)] pub struct AppState { - pub skels: Vec>>, + pub skels: Vec>>>, pub ifindices: Vec, pub bpf_stats_collector: BpfStatsCollector, pub tcp_fingerprint_collector: TcpFingerprintCollector, diff --git a/src/core/cli.rs b/src/core/cli.rs index 9d1babb..a19f755 100644 --- a/src/core/cli.rs +++ b/src/core/cli.rs @@ -142,6 +142,48 @@ pub struct NetworkConfig { pub ip_version: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RateLimiterConfig { + #[serde(default = "default_status")] + pub enabled: bool, + #[serde(default = "default_request_per_sec")] + pub request_per_sec: u64, + #[serde(default = "default_burst_factor")] + pub burst_factor: f32, + #[serde(default = "default_ratelimit_map_size")] + pub ipv4_map_size: u32, + #[serde(default = "default_ratelimit_map_size")] + pub ipv6_map_size: u32, +} + +fn default_status() -> bool { + return true; +} + +fn default_request_per_sec() -> u64 { + 1000 +} + +fn default_burst_factor() -> f32 { + 1.5 +} + +fn default_ratelimit_map_size() -> u32 { + 50000 +} + +impl Default for RateLimiterConfig { + fn default() -> Self { + Self { + enabled: true, + request_per_sec: default_request_per_sec(), + burst_factor: default_burst_factor(), + ipv4_map_size: default_ratelimit_map_size(), + ipv6_map_size: default_ratelimit_map_size(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct FirewallConfig { /// Firewall backend mode: auto, xdp, nftables, iptables, none @@ -149,6 +191,8 @@ pub struct FirewallConfig { pub mode: crate::firewall::FirewallMode, #[serde(default)] pub disable_xdp: bool, + #[serde(default)] + pub ratelimiter: RateLimiterConfig, } fn default_ip_version() -> String { @@ -382,6 +426,7 @@ impl Config { firewall: FirewallConfig { mode: crate::firewall::FirewallMode::default(), disable_xdp: false, + ratelimiter: RateLimiterConfig::default(), }, platform: Gen0SecConfig { api_key: "".to_string(), diff --git a/src/logger/bpf_stats.rs b/src/logger/bpf_stats.rs index 333b514..cadf8cf 100644 --- a/src/logger/bpf_stats.rs +++ b/src/logger/bpf_stats.rs @@ -4,9 +4,9 @@ use libbpf_rs::MapCore; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::{Ipv4Addr, Ipv6Addr}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; -use crate::security::firewall::bpf::FilterSkel; +use crate::security::firewall::bpf::XdpSkel; /// BPF statistics collected from kernel-level access rule enforcement #[derive(Debug, Clone, Serialize, Deserialize)] @@ -70,7 +70,7 @@ pub struct DroppedIpEvents { impl BpfAccessStats { /// Create a new statistics snapshot from BPF maps - pub fn from_bpf_maps(skel: &FilterSkel) -> Result> { + pub fn from_bpf_maps(skel: &XdpSkel) -> Result> { let timestamp = Utc::now(); // Read statistics from BPF maps @@ -138,7 +138,7 @@ impl BpfAccessStats { /// Collect dropped IP addresses from BPF maps fn collect_dropped_ip_addresses( - skel: &FilterSkel, + skel: &XdpSkel, ) -> Result> { let mut ipv4_addresses = HashMap::new(); let mut ipv6_addresses = HashMap::new(); @@ -457,13 +457,13 @@ impl DroppedIpEvents { /// Statistics collector for BPF access rules #[derive(Clone)] pub struct BpfStatsCollector { - skels: Vec>>, + skels: Vec>>>, enabled: bool, } impl BpfStatsCollector { /// Create a new statistics collector - pub fn new(skels: Vec>>, enabled: bool) -> Self { + pub fn new(skels: Vec>>>, enabled: bool) -> Self { Self { skels, enabled } } @@ -485,7 +485,15 @@ impl BpfStatsCollector { let mut stats = Vec::new(); for skel in &self.skels { - match BpfAccessStats::from_bpf_maps(skel) { + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!("Failed to lock BPF skeleton: {}", e); + continue; + } + }; + + match BpfAccessStats::from_bpf_maps(&*skel_guard) { Ok(stat) => stats.push(stat), Err(e) => { log::warn!("Failed to collect BPF stats from skeleton: {}", e); @@ -608,7 +616,16 @@ impl BpfStatsCollector { let mut events = DroppedIpEvents::new(); for skel in &self.skels { - let dropped_ips = BpfAccessStats::collect_dropped_ip_addresses(skel)?; + let dropped_ips = { + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + return Err(format!("Failed to lock BPF skeleton: {}", e).into()); + } + }; + + BpfAccessStats::collect_dropped_ip_addresses(&*skel_guard)? + }; // Convert IPv4 addresses to events for (ip_str, count) in dropped_ips.ipv4_addresses { @@ -682,8 +699,16 @@ impl BpfStatsCollector { log::debug!("Resetting dropped IP address counters"); for skel in &self.skels { + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!("Failed to lock BPF skeleton for counter reset: {}", e); + continue; + } + }; + // Reset IPv4 counters - match skel.maps.dropped_ipv4_addresses.lookup_batch( + match skel_guard.maps.dropped_ipv4_addresses.lookup_batch( 1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -693,7 +718,7 @@ impl BpfStatsCollector { for (key_bytes, _) in batch_iter { if key_bytes.len() >= 4 { let zero_count = 0u64.to_le_bytes(); - if let Err(e) = skel.maps.dropped_ipv4_addresses.update( + if let Err(e) = skel_guard.maps.dropped_ipv4_addresses.update( &key_bytes, &zero_count, libbpf_rs::MapFlags::ANY, @@ -712,7 +737,7 @@ impl BpfStatsCollector { } // Reset IPv6 counters - match skel.maps.dropped_ipv6_addresses.lookup_batch( + match skel_guard.maps.dropped_ipv6_addresses.lookup_batch( 1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -722,7 +747,7 @@ impl BpfStatsCollector { for (key_bytes, _) in batch_iter { if key_bytes.len() >= 16 { let zero_count = 0u64.to_le_bytes(); - if let Err(e) = skel.maps.dropped_ipv6_addresses.update( + if let Err(e) = skel_guard.maps.dropped_ipv6_addresses.update( &key_bytes, &zero_count, libbpf_rs::MapFlags::ANY, diff --git a/src/main.rs b/src/main.rs index c45569f..27d468b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use daemonize::Daemonize; use libbpf_rs::skel::{OpenSkel, SkelBuilder}; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] use nix::net::if_::if_nametoindex; +use nix::sys::resource::{Resource, setrlimit}; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -346,7 +347,7 @@ async fn async_main(args: Args, config: Config) -> Result<()> { use std::sync::Mutex; #[allow(unused_mut)] - let mut skels: Vec>> = Vec::new(); + let mut skels: Vec>>> = Vec::new(); #[allow(unused_mut)] let mut ifindices: Vec = Vec::new(); let mut firewall_backend = FirewallBackend::None; @@ -361,76 +362,125 @@ async fn async_main(args: Args, config: Config) -> Result<()> { if config.firewall.disable_xdp { log::warn!("XDP disabled by config, will use nftables fallback"); } else { - #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] - { - // Suppress libbpf output - errors will be handled gracefully with fallback to nftables/iptables - libbpf_rs::set_print(None); - - for iface in &iface_names { - let boxed_open: Box> = - Box::new(MaybeUninit::uninit()); - let open_object: &'static mut MaybeUninit = - Box::leak(boxed_open); - let skel_builder = bpf::FilterSkelBuilder::default(); - match skel_builder.open(open_object).and_then(|o| o.load()) { - Ok(mut skel) => { - let ifindex = match if_nametoindex(iface.as_str()) { - Ok(index) => index as i32, - Err(e) => { - log::error!("failed to get interface index for '{}': {e}", iface); - continue; - } - }; - match bpf_attach_to_xdp( - &mut skel, - ifindex, - Some(iface.as_str()), - &config.network.ip_version, - ) { - Ok(mode) => { - xdp_modes.push((iface.as_str(), mode.as_str())); - skels.push(Arc::new(skel)); - ifindices.push(ifindex); + for iface in &iface_names { + let boxed_open: Box> = + Box::new(MaybeUninit::uninit()); + let open_object: &'static mut MaybeUninit = + Box::leak(boxed_open); + let skel_builder = bpf::XdpSkelBuilder::default(); + match skel_builder.open(open_object) { + Ok(mut open_skel) => { + // Set map sizes from config before loading + if let Err(e) = crate::security::ratelimiter::set_bucket_map_size_ipv4( + &mut open_skel, + config.firewall.ratelimiter.ipv4_map_size, + ) { + log::warn!( + "Failed to set IPv4 bucket map size for '{}': {} -> falling back to default", + iface, + e + ); + } else { + log::debug!( + "Successfully set ratelimiter ipv4 map size: {}", + config.firewall.ratelimiter.ipv4_map_size + ); + } + + if let Err(e) = crate::security::ratelimiter::set_bucket_map_size_ipv6( + &mut open_skel, + config.firewall.ratelimiter.ipv6_map_size, + ) { + log::warn!( + "Failed to set IPv6 bucket map size for '{}': {} -> falling back to default", + iface, + e + ); + } else { + log::debug!( + "Successfully set ratelimiter ipv6 map size: {}", + config.firewall.ratelimiter.ipv6_map_size + ); + } + + match open_skel.load() { + Ok(mut skel) => { + { + // Apply rate limiter configuration + let mut ratelimiter = + crate::security::ratelimiter::XDPRateLimit::new(&mut skel); + ratelimiter.set_request_per_sec( + config.firewall.ratelimiter.request_per_sec, + config.firewall.ratelimiter.burst_factor, + ); + + ratelimiter + .set_ratelimiter_status(!!config.firewall.ratelimiter.enabled); } - Err(e) => { - // Check if error is EAFNOSUPPORT (error 97) - IPv6 might be disabled - let error_str = e.to_string(); - if error_str.contains("97") - || error_str.contains("Address family not supported") - { - log::warn!( - "Failed to attach XDP to '{}': {} (IPv6 disabled)", - iface, - e + + let ifindex = match if_nametoindex(iface.as_str()) { + Ok(index) => index as i32, + Err(e) => { + log::error!( + "failed to get interface index for '{}': {e}", + iface ); - } else { - log::error!("Failed to attach XDP to '{}': {}", iface, e); + continue; + } + }; + let mut flags = libbpf_rs::XdpFlags::DRV_MODE; + if iface.starts_with("lo") { + flags = libbpf_rs::XdpFlags::SKB_MODE; + log::warn!("Forcing SKB mode for loopback interface"); + } + match bpf_attach_to_xdp( + &mut skel, + ifindex, + Some(iface.as_str()), + &config.network.ip_version, + ) { + Ok(mode) => { + xdp_modes.push((iface.as_str(), mode.as_str())); + skels.push(Arc::new(Mutex::new(skel))); + ifindices.push(ifindex); + } + Err(e) => { + // Check if error is EAFNOSUPPORT (error 97) - IPv6 might be disabled + let error_str = e.to_string(); + if error_str.contains("97") + || error_str.contains("Address family not supported") + { + log::warn!( + "Failed to attach XDP to '{}': {} (IPv6 disabled)", + iface, + e + ); + } else { + log::error!("Failed to attach XDP to '{}': {}", iface, e); + } } } } - } - Err(e) => { - // Check for common BPF/kernel support errors - let error_str = e.to_string(); - if error_str.contains("Function not implemented") - || error_str.contains("ENOSYS") - || error_str.contains("error 38") - || error_str.contains("CONFIG_BPF_SYSCALL") - { - log::warn!("BPF not supported by kernel for '{}': {}", iface, e); - } else { - log::warn!("failed to load BPF skeleton for '{}': {e}", iface); + Err(e) => { + // Check for common BPF/kernel support errors + let error_str = e.to_string(); + if error_str.contains("Function not implemented") + || error_str.contains("ENOSYS") + || error_str.contains("error 38") + || error_str.contains("CONFIG_BPF_SYSCALL") + { + log::warn!("BPF not supported by kernel for '{}': {}", iface, e); + } else { + log::warn!("failed to load BPF skeleton for '{}': {e}", iface); + } } } } + Err(e) => { + log::warn!("failed to open BPF skeleton for '{}': {e}", iface); + } } } - - #[cfg(any(not(feature = "bpf"), feature = "disable-bpf"))] - { - log::warn!("BPF support disabled at build time; skipping XDP attachment"); - log::warn!("Set 'disable_xdp: true' in config.yaml to suppress this warning"); - } } // Determine firewall backend based on config and availability diff --git a/src/proxy/proxyhttp.rs b/src/proxy/proxyhttp.rs index 17befc2..5c330cf 100644 --- a/src/proxy/proxyhttp.rs +++ b/src/proxy/proxyhttp.rs @@ -464,10 +464,18 @@ impl ProxyHttp for LB { if let Some(rate_limit_config) = &waf_result.rate_limit_config { let client_ip = peer_addr.ip().to_string(); + + // Convert worker::config::RateLimitConfig to security::waf::actions::rate_limit::RateLimitConfig + let action_config = rate_limit::RateLimitConfig { + period: rate_limit_config.period.clone(), + duration: rate_limit_config.duration.clone(), + requests: rate_limit_config.requests.clone(), + }; + let result = rate_limit::check_rate_limit( &waf_result.rule_id, &client_ip, - rate_limit_config, + &action_config, ); if result.exceeded { diff --git a/src/security/access_rules.rs b/src/security/access_rules.rs index ded01ab..64e00ab 100644 --- a/src/security/access_rules.rs +++ b/src/security/access_rules.rs @@ -3,19 +3,25 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; use std::sync::{Arc, Mutex, OnceLock}; -use crate::firewall::{Firewall, IptablesFirewall, NftablesFirewall, SYNAPSEFirewall}; +use tracing::info; + +use crate::security::firewall::{Firewall, IptablesFirewall, NftablesFirewall, SYNAPSEFirewall}; +use crate::security::ratelimiter::XDPRateLimit; use crate::security::waf::wirefilter::update_http_filter_from_config_value; use crate::utils::http_utils::{is_ip_in_cidr, parse_ip_or_cidr}; -use crate::worker::config; use crate::worker::config::global_config; +use crate::worker::config::{self, XDPRateLimitConfig}; // Store previous rules state for comparison type PreviousRules = Arc>>; type PreviousRulesV6 = Arc>>; +type PreviousXDPRateLimiterConfig = Arc>; + // Global previous rules state for the access rules worker static PREVIOUS_RULES: OnceLock = OnceLock::new(); static PREVIOUS_RULES_V6: OnceLock = OnceLock::new(); +static PREVIOUS_RATE_LIMITER_CONFIG: OnceLock = OnceLock::new(); fn get_previous_rules() -> &'static PreviousRules { PREVIOUS_RULES.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) @@ -25,9 +31,19 @@ fn get_previous_rules_v6() -> &'static PreviousRulesV6 { PREVIOUS_RULES_V6.get_or_init(|| Arc::new(Mutex::new(HashSet::new()))) } +fn get_previous_xd_rate_limiter_config() -> &'static PreviousXDPRateLimiterConfig { + PREVIOUS_RATE_LIMITER_CONFIG.get_or_init(|| { + Arc::new(Mutex::new(XDPRateLimitConfig { + enabled: false, + requests: 1000, + burst_factor: 1.0, + })) + }) +} + /// Apply access rules once using the current global config snapshot (initial setup) pub fn init_access_rules_from_global( - skels: &Vec>>, + skels: &Vec>>>, ) -> Result<(), Box> { if skels.is_empty() { return Ok(()); @@ -38,12 +54,24 @@ pub fn init_access_rules_from_global( Arc::new(Mutex::new(std::collections::HashSet::new())); let previous_rules_v6: PreviousRulesV6 = Arc::new(Mutex::new(std::collections::HashSet::new())); + let previous_rate_limiter_config: PreviousXDPRateLimiterConfig = + Arc::new(Mutex::new(XDPRateLimitConfig { + enabled: false, + requests: 1000, + burst_factor: 1.0, + })); // Use Arc clone instead of full Config clone for efficiency let resp = config::ConfigApiResponse { success: true, config: (**cfg_arc).clone(), }; - apply_rules(skels, &resp, &previous_rules, &previous_rules_v6)?; + apply_rules( + skels, + &resp, + &previous_rules, + &previous_rules_v6, + &previous_rate_limiter_config, + )?; } } Ok(()) @@ -53,9 +81,10 @@ pub fn init_access_rules_from_global( /// This is called periodically by the ConfigWorker after it fetches new config /// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates pub fn apply_rules_from_global( - skels: &Vec>>, + skels: &Vec>>>, previous_rules: &PreviousRules, previous_rules_v6: &PreviousRulesV6, + previous_rate_limiter_config: &PreviousXDPRateLimiterConfig, skip_waf_update: bool, ) -> Result<(), Box> { // Read from global config and apply if available @@ -79,6 +108,7 @@ pub fn apply_rules_from_global( }, previous_rules, previous_rules_v6, + previous_rate_limiter_config, )?; return Ok(()); } @@ -90,13 +120,14 @@ pub fn apply_rules_from_global( /// This is called by the ConfigWorker after it fetches new config /// Set `skip_waf_update` to true in agent mode to skip WAF wirefilter updates pub fn apply_rules_from_global_with_state( - skels: &Vec>>, + skels: &Vec>>>, skip_waf_update: bool, ) -> Result<(), Box> { apply_rules_from_global( skels, get_previous_rules(), get_previous_rules_v6(), + get_previous_xd_rate_limiter_config(), skip_waf_update, ) } @@ -468,7 +499,7 @@ fn apply_rules_to_iptables( "iptables: IPv4 rules updated (+{} -{} total={})", to_add.len(), to_remove.len(), - current_rules.len() + current_rules.len(), ); } if !to_add_v6.is_empty() || !to_remove_v6.is_empty() { @@ -476,7 +507,7 @@ fn apply_rules_to_iptables( "iptables: IPv6 rules updated (+{} -{} total={})", to_add_v6.len(), to_remove_v6.len(), - current_rules_v6.len() + current_rules_v6.len(), ); } } @@ -493,10 +524,11 @@ fn apply_rules_to_iptables( } fn apply_rules( - skels: &Vec>>, + skels: &Vec>>>, resp: &config::ConfigApiResponse, previous_rules: &PreviousRules, previous_rules_v6: &PreviousRulesV6, + previous_rules_rate_limiter: &PreviousXDPRateLimiterConfig, ) -> Result<(), Box> { fn parse_ipv4_ip_or_cidr(entry: &str) -> Option<(Ipv4Addr, u32)> { let s = entry.trim(); @@ -633,72 +665,108 @@ fn apply_rules( // Compare with previous rules to detect changes let mut previous_rules_guard = previous_rules.lock().unwrap(); let mut previous_rules_v6_guard = previous_rules_v6.lock().unwrap(); + let mut previous_rate_limiter_config = previous_rules_rate_limiter.lock().unwrap(); // Check if rules have changed let ipv4_changed = *previous_rules_guard != current_rules; let ipv6_changed = *previous_rules_v6_guard != current_rules_v6; + let rate_limiter_config_changed = *previous_rate_limiter_config != rule.config.rate_limit; // If neither family changed, skip quietly with a single log entry - if !ipv4_changed && !ipv6_changed { - log::debug!("No IPv4 or IPv6 access rule changes detected, skipping BPF map updates"); - return Ok(()); - } - - // Compute diffs once against snapshots - let prev_v4_snapshot = previous_rules_guard.clone(); - let prev_v6_snapshot = previous_rules_v6_guard.clone(); - let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot - .difference(¤t_rules) - .cloned() - .collect(); - let added_v4: Vec<(Ipv4Addr, u32)> = current_rules - .difference(&prev_v4_snapshot) - .cloned() - .collect(); - let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot - .difference(¤t_rules_v6) - .cloned() - .collect(); - let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6 - .difference(&prev_v6_snapshot) - .cloned() - .collect(); + if ipv4_changed || ipv6_changed { + // Compute diffs once against snapshots + let prev_v4_snapshot = previous_rules_guard.clone(); + let prev_v6_snapshot = previous_rules_v6_guard.clone(); + let removed_v4: Vec<(Ipv4Addr, u32)> = prev_v4_snapshot + .difference(¤t_rules) + .cloned() + .collect(); + let added_v4: Vec<(Ipv4Addr, u32)> = current_rules + .difference(&prev_v4_snapshot) + .cloned() + .collect(); + let removed_v6: Vec<(Ipv6Addr, u32)> = prev_v6_snapshot + .difference(¤t_rules_v6) + .cloned() + .collect(); + let added_v6: Vec<(Ipv6Addr, u32)> = current_rules_v6 + .difference(&prev_v6_snapshot) + .cloned() + .collect(); + + // Apply to all BPF skeletons + for s in skels.iter() { + let skel_ref_res = s.lock(); + let skel_ref = match skel_ref_res { + Ok(skel_ref) => skel_ref, + Err(e) => { + let err_msg = format!("Could not take skel mutex in thread: {}", e); + log::error!("{err_msg}"); + return Err(err_msg.into()); + } + }; - // Apply to all BPF skeletons - for s in skels.iter() { - let mut fw = SYNAPSEFirewall::new(s); - if ipv4_changed { - for (net, prefix) in &removed_v4 { - if let Err(e) = fw.unban_ip(*net, *prefix) { - log::error!("IPv4 unban failed for {}/{}: {}", net, prefix, e); + let mut fw = SYNAPSEFirewall::new(&skel_ref); + if ipv4_changed { + for (net, prefix) in &removed_v4 { + if let Err(e) = fw.unban_ip(*net, *prefix) { + log::error!("IPv4 unban failed for {}/{}: {}", net, prefix, e); + } } - } - for (net, prefix) in &added_v4 { - if let Err(e) = fw.ban_ip(*net, *prefix) { - log::error!("IPv4 ban failed for {}/{}: {}", net, prefix, e); + for (net, prefix) in &added_v4 { + if let Err(e) = fw.ban_ip(*net, *prefix) { + log::error!("IPv4 ban failed for {}/{}: {}", net, prefix, e); + } } } - } - if ipv6_changed { - for (net, prefix) in &removed_v6 { - if let Err(e) = fw.unban_ipv6(*net, *prefix) { - log::error!("IPv6 unban failed for {}/{}: {}", net, prefix, e); + if ipv6_changed { + for (net, prefix) in &removed_v6 { + if let Err(e) = fw.unban_ipv6(*net, *prefix) { + log::error!("IPv6 unban failed for {}/{}: {}", net, prefix, e); + } } - } - for (net, prefix) in &added_v6 { - if let Err(e) = fw.ban_ipv6(*net, *prefix) { - log::error!("IPv6 ban failed for {}/{}: {}", net, prefix, e); + for (net, prefix) in &added_v6 { + if let Err(e) = fw.ban_ipv6(*net, *prefix) { + log::error!("IPv6 ban failed for {}/{}: {}", net, prefix, e); + } } } } - } - // Update previous snapshots once after applying to all skels - if ipv4_changed { - *previous_rules_guard = current_rules; + // Update previous snapshots + if ipv4_changed { + *previous_rules_guard = current_rules; + } + if ipv6_changed { + *previous_rules_v6_guard = current_rules_v6; + } + } else { + log::debug!("No IPv4 or IPv6 access rule changes detected, skipping BPF map updates"); } - if ipv6_changed { - *previous_rules_v6_guard = current_rules_v6; + + if rate_limiter_config_changed { + for s in skels.iter() { + let skel_ref_res = s.lock(); + let mut skel_ref = match skel_ref_res { + Ok(skel_ref) => skel_ref, + Err(e) => { + let err_msg = format!("Could not take skel mutex in thread: {}", e); + log::error!("{err_msg}"); + return Err(err_msg.into()); + } + }; + + let mut ratelimit = XDPRateLimit::new(&mut skel_ref); + ratelimit.setup_from_config(&rule.config.rate_limit); + + log::debug!("Successfully set XDP ratelimter config via access rules"); + } + + if rate_limiter_config_changed { + *previous_rate_limiter_config = rule.config.rate_limit.clone(); + } + } else { + log::debug!("No XDP rate limiter config cahnges detected, skipping update"); } Ok(()) diff --git a/src/security/firewall/bpf/compile_flags.txt b/src/security/firewall/bpf/compile_flags.txt new file mode 100644 index 0000000..8e7f5e1 --- /dev/null +++ b/src/security/firewall/bpf/compile_flags.txt @@ -0,0 +1,26 @@ +-I +/home/gepsonka/lib +-I +include +-I +lib +-I +. +-I +/usr/include +-I +/usr/include/x86_64-linux-gnu +-I +/usr/include/bpf +-I +/usr/local/include/bpf +-target +bpf +-O2 +-g +-Wall +-Wno-int-conversion +-Wno-unused-value +-Wno-pointer-sign +-Wno-compare-distinct-pointer-types +-D__TARGET_ARCH_x86 \ No newline at end of file diff --git a/src/security/firewall/bpf/filter.bpf.c b/src/security/firewall/bpf/filter.bpf.c deleted file mode 100644 index 981e2eb..0000000 --- a/src/security/firewall/bpf/filter.bpf.c +++ /dev/null @@ -1,766 +0,0 @@ -#undef bpf // Undefine bpf macro to avoid conflict with struct netns_bpf bpf in vmlinux.h -#include "vmlinux.h" -#include -#include -#include "filter.h" - -#define NF_DROP 0 -#define NF_ACCEPT 1 -#define ETH_P_IP 0x0800 -#define ETH_P_IPV6 0x86DD -#define IP_MF 0x2000 -#define IP_OFFSET 0x1FFF -#define NEXTHDR_FRAGMENT 44 - -// TCP fingerprinting feature flag - comment out to disable and reduce program size -#define ENABLE_TCP_FINGERPRINTING - -// TCP fingerprinting constants -#define TCP_FINGERPRINT_MAX_ENTRIES 10000 -#define TCP_FP_KEY_SIZE 20 // 4 bytes IP + 2 bytes port + 14 bytes fingerprint -#define TCP_FP_MAX_OPTIONS 10 -#define TCP_FP_MAX_OPTION_LEN 16 // Reduced from 40 to minimize BPF instruction count - - -// Fragmentation checks removed - not currently used - - -struct lpm_key { - __u32 prefixlen; - __be32 addr; -}; - -struct lpm_key_v6 { - __u32 prefixlen; - __u8 addr[16]; -}; - -// TCP fingerprinting structures -struct tcp_fingerprint_key { - __be32 src_ip; // Source IP address (IPv4) - __be16 src_port; // Source port - __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) -}; - -struct tcp_fingerprint_key_v6 { - __u8 src_ip[16]; // Source IP address (IPv6) - __be16 src_port; // Source port - __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) -}; - -struct tcp_fingerprint_data { - __u64 first_seen; // Timestamp of first packet - __u64 last_seen; // Timestamp of last packet - __u32 packet_count; // Number of packets seen - __u16 ttl; // Initial TTL - __u16 mss; // Maximum Segment Size - __u16 window_size; // TCP window size - __u8 window_scale; // Window scaling factor - __u8 options_len; // Length of TCP options - __u8 options[TCP_FP_MAX_OPTION_LEN]; // TCP options data -}; - -struct tcp_syn_stats { - __u64 total_syns; - __u64 unique_fingerprints; - __u64 last_reset; -}; - -// IPv4 maps: permanently banned and recently banned -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key); // IPv4 address in network byte order - __type(value, ip_flag_t); // presence flag (1) -} banned_ips SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key); - __type(value, ip_flag_t); -} recently_banned_ips SEC(".maps"); - -// IPv6 maps: permanently banned and recently banned -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key_v6); - __type(value, ip_flag_t); -} banned_ips_v6 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LPM_TRIE); - __uint(max_entries, CITADEL_IP_MAP_MAX); - __uint(map_flags, BPF_F_NO_PREALLOC); - __type(key, struct lpm_key_v6); - __type(value, ip_flag_t); -} recently_banned_ips_v6 SEC(".maps"); - -// Remove dynptr helpers, not used in XDP manual parsing -// extern int bpf_dynptr_from_skb(struct __sk_buff *skb, __u64 flags, -// struct bpf_dynptr *ptr__uninit) __ksym; -// extern void *bpf_dynptr_slice(const struct bpf_dynptr *ptr, uint32_t offset, -// void *buffer, uint32_t buffer__sz) __ksym; - -volatile int shootdowns = 0; - -// Statistics maps for tracking access rule hits -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv4_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv4_recently_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv6_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} ipv6_recently_banned_stats SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} total_packets_processed SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} total_packets_dropped SEC(".maps"); - -// TCP fingerprinting maps -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); - __type(key, struct tcp_fingerprint_key); - __type(value, struct tcp_fingerprint_data); -} tcp_fingerprints SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); - __type(key, struct tcp_fingerprint_key_v6); - __type(value, struct tcp_fingerprint_data); -} tcp_fingerprints_v6 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, struct tcp_syn_stats); -} tcp_syn_stats SEC(".maps"); - -// Blocked TCP fingerprint maps (only store the fingerprint string, not per-IP) -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __uint(max_entries, 10000); // Store up to 10k blocked fingerprint patterns - __type(key, __u8[14]); // TCP fingerprint string (14 bytes) - __type(value, __u8); // Flag (1 = blocked) -} blocked_tcp_fingerprints SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __uint(max_entries, 10000); - __type(key, __u8[14]); // TCP fingerprint string (14 bytes) - __type(value, __u8); // Flag (1 = blocked) -} blocked_tcp_fingerprints_v6 SEC(".maps"); - -// Statistics for TCP fingerprint blocks -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} tcp_fingerprint_blocks_ipv4 SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, __u64); -} tcp_fingerprint_blocks_ipv6 SEC(".maps"); - -// Maps to track dropped IP addresses with counters -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1000); // Track up to 1000 unique dropped IPs - __type(key, __be32); // IPv4 address - __type(value, __u64); // Drop count -} dropped_ipv4_addresses SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1000); // Track up to 1000 unique dropped IPv6s - __type(key, __u8[16]); // IPv6 address - __type(value, __u64); // Drop count -} dropped_ipv6_addresses SEC(".maps"); - -/* - * Helper for bounds checking and advancing a cursor. - * - * @cursor: pointer to current parsing position - * @end: pointer to end of packet data - * @len: length of the struct to read - * - * Returns a pointer to the struct if it's within bounds, - * and advances the cursor. Returns NULL otherwise. - */ -static void *parse_and_advance(void **cursor, void *end, __u32 len) -{ - void *current = *cursor; - if (current + len > end) - return NULL; - *cursor = current + len; - return current; -} - -/* - * Helper functions for incrementing statistics counters - */ -static void increment_ipv4_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv4_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv4_recently_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv4_recently_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv6_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv6_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_ipv6_recently_banned_stats(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&ipv6_recently_banned_stats, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_total_packets_processed(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&total_packets_processed, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_total_packets_dropped(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&total_packets_dropped, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_dropped_ipv4_address(__be32 ip_addr) -{ - __u64 *value = bpf_map_lookup_elem(&dropped_ipv4_addresses, &ip_addr); - if (value) { - __sync_fetch_and_add(value, 1); - } else { - // First time dropping this IP, initialize counter - __u64 initial_count = 1; - bpf_map_update_elem(&dropped_ipv4_addresses, &ip_addr, &initial_count, BPF_ANY); - } -} - -static void increment_dropped_ipv6_address(struct in6_addr ip_addr) -{ - __u8 *addr_bytes = (__u8 *)&ip_addr; - __u64 *value = bpf_map_lookup_elem(&dropped_ipv6_addresses, addr_bytes); - if (value) { - __sync_fetch_and_add(value, 1); - } else { - // First time dropping this IP, initialize counter - __u64 initial_count = 1; - bpf_map_update_elem(&dropped_ipv6_addresses, addr_bytes, &initial_count, BPF_ANY); - } -} - -/* - * TCP fingerprinting helper functions - */ -static void increment_tcp_syn_stats(void) -{ - __u32 key = 0; - struct tcp_syn_stats *stats = bpf_map_lookup_elem(&tcp_syn_stats, &key); - if (stats) { - __sync_fetch_and_add(&stats->total_syns, 1); - } else { - struct tcp_syn_stats new_stats = {0}; - new_stats.total_syns = 1; - bpf_map_update_elem(&tcp_syn_stats, &key, &new_stats, BPF_ANY); - } -} - -// increment_unique_fingerprints removed - fingerprint recording disabled - -static void increment_tcp_fingerprint_blocks_ipv4(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv4, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static void increment_tcp_fingerprint_blocks_ipv6(void) -{ - __u32 key = 0; - __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv6, &key); - if (value) { - __sync_fetch_and_add(value, 1); - } -} - -static int parse_tcp_mss_wscale(struct tcphdr *tcp, void *data_end, __u16 *mss_out, __u8 *wscale_out) -{ - __u8 *ptr = (__u8 *)tcp + sizeof(struct tcphdr); - __u32 options_len = (tcp->doff * 4) - sizeof(struct tcphdr); - __u8 *end = ptr + options_len; - - // Ensure we don't exceed packet bounds - if (end > (__u8 *)data_end) { - end = (__u8 *)data_end; - } - - // Safety check - if (ptr >= end || ptr >= (__u8 *)data_end) { - return 0; - } - - // Parse options - limit to 10 iterations to reduce instruction count - #pragma unroll - for (int i = 0; i < 10; i++) { - if (ptr >= end || ptr >= (__u8 *)data_end) break; - if (ptr + 1 > (__u8 *)data_end) break; - - __u8 kind = *ptr; - if (kind == 0) break; // End of options - - if (kind == 1) { - // NOP option - ptr++; - continue; - } - - // Check bounds for option length - if (ptr + 2 > (__u8 *)data_end) break; - __u8 len = *(ptr + 1); - if (len < 2 || ptr + len > (__u8 *)data_end) break; - - // MSS option (kind=2, len=4) - if (kind == 2 && len == 4 && ptr + 4 <= (__u8 *)data_end) { - *mss_out = (*(ptr + 2) << 8) | *(ptr + 3); - } - // Window scale option (kind=3, len=3) - else if (kind == 3 && len == 3 && ptr + 3 <= (__u8 *)data_end) { - *wscale_out = *(ptr + 2); - } - - ptr += len; - } - - return 0; -} - -static void generate_tcp_fingerprint(struct tcphdr *tcp, void *data_end, __u16 ttl, __u8 *fingerprint) -{ - // Generate JA4T-style fingerprint: ttl:mss:window:scale - __u16 mss = 0; - __u8 window_scale = 0; - - // Parse TCP options to extract MSS and window scaling - parse_tcp_mss_wscale(tcp, data_end, &mss, &window_scale); - - // Generate fingerprint string manually (BPF doesn't support complex formatting) - __u16 window = bpf_ntohs(tcp->window); - - // Format: "ttl:mss:window:scale" (max 14 chars) - fingerprint[0] = '0' + (ttl / 100); - fingerprint[1] = '0' + ((ttl / 10) % 10); - fingerprint[2] = '0' + (ttl % 10); - fingerprint[3] = ':'; - fingerprint[4] = '0' + (mss / 1000); - fingerprint[5] = '0' + ((mss / 100) % 10); - fingerprint[6] = '0' + ((mss / 10) % 10); - fingerprint[7] = '0' + (mss % 10); - fingerprint[8] = ':'; - fingerprint[9] = '0' + (window / 10000); - fingerprint[10] = '0' + ((window / 1000) % 10); - fingerprint[11] = '0' + ((window / 100) % 10); - fingerprint[12] = '0' + ((window / 10) % 10); - fingerprint[13] = '0' + (window % 10); - // Note: window_scale is not included due to space constraints -} - -/* - * Check if a TCP fingerprint is blocked (IPv4) - * Returns true if the fingerprint should be blocked - */ -static bool is_tcp_fingerprint_blocked(__u8 *fingerprint) -{ - __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints, fingerprint); - return (blocked != NULL && *blocked == 1); -} - -/* - * Check if a TCP fingerprint is blocked (IPv6) - * Returns true if the fingerprint should be blocked - */ -static bool is_tcp_fingerprint_blocked_v6(__u8 *fingerprint) -{ - __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints_v6, fingerprint); - return (blocked != NULL && *blocked == 1); -} - -// TCP fingerprint recording info struct to pass data with fewer args -struct fp_record_info { - __be32 src_ip_v4; - __be16 src_port; - __u16 ttl; - __u8 fingerprint[14]; - __u8 src_ip_v6[16]; -}; - -/* - * Record TCP fingerprint for monitoring (IPv4) - * Uses struct to pass data within BPF 5-arg limit - */ -static void record_tcp_fingerprint_v4(struct fp_record_info *info, - struct tcphdr *tcp, void *data_end) -{ - // Skip localhost (127.0.0.0/8) - if ((info->src_ip_v4 & bpf_htonl(0xff000000)) == bpf_htonl(0x7f000000)) - return; - - struct tcp_fingerprint_key key = {0}; - key.src_ip = info->src_ip_v4; - key.src_port = info->src_port; - __builtin_memcpy(key.fingerprint, info->fingerprint, 14); - - // Only create new entries - if (bpf_map_lookup_elem(&tcp_fingerprints, &key)) - return; - - struct tcp_fingerprint_data data = {0}; - data.first_seen = bpf_ktime_get_ns(); - data.last_seen = data.first_seen; - data.packet_count = 1; - data.ttl = info->ttl; - data.window_size = bpf_ntohs(tcp->window); - parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); - - bpf_map_update_elem(&tcp_fingerprints, &key, &data, BPF_NOEXIST); -} - -/* - * Record TCP fingerprint for monitoring (IPv6) - */ -static void record_tcp_fingerprint_v6(struct fp_record_info *info, - struct tcphdr *tcp, void *data_end) -{ - struct tcp_fingerprint_key_v6 key = {0}; - __builtin_memcpy(key.src_ip, info->src_ip_v6, 16); - key.src_port = info->src_port; - __builtin_memcpy(key.fingerprint, info->fingerprint, 14); - - // Only create new entries - if (bpf_map_lookup_elem(&tcp_fingerprints_v6, &key)) - return; - - struct tcp_fingerprint_data data = {0}; - data.first_seen = bpf_ktime_get_ns(); - data.last_seen = data.first_seen; - data.packet_count = 1; - data.ttl = info->ttl; - data.window_size = bpf_ntohs(tcp->window); - parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); - - bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_NOEXIST); -} - -SEC("xdp") -int arxignis_xdp_filter(struct xdp_md *ctx) -{ - // This filter is designed to only block incoming traffic - // It should be attached only to ingress hooks, not egress - // The filtering logic below blocks packets based on source IP addresses - // - // IP Version Support: - // - Supports IPv4-only, IPv6-only, and hybrid (both) modes - // - Note: XDP requires IPv6 to be enabled at kernel level for attachment, - // even when processing only IPv4 packets. This is a kernel limitation. - // - The BPF program processes both IPv4 and IPv6 packets based on the - // ethernet protocol type (ETH_P_IP for IPv4, ETH_P_IPV6 for IPv6) - - void *data_end = (void *)(long)ctx->data_end; - void *cursor = (void *)(long)ctx->data; - - // Debug: Count all packets - __u32 zero = 0; - __u32 *packet_count = bpf_map_lookup_elem(&total_packets_processed, &zero); - if (packet_count) { - __sync_fetch_and_add(packet_count, 1); - } - - struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth)); - if (!eth) - return XDP_PASS; - - __u16 h_proto = eth->h_proto; - - // Increment total packets processed counter - increment_total_packets_processed(); - - if (h_proto == bpf_htons(ETH_P_IP)) { - struct iphdr *iph = parse_and_advance(&cursor, data_end, sizeof(*iph)); - if (!iph) - return XDP_PASS; - - struct lpm_key key = { - .prefixlen = 32, - .addr = iph->saddr, - }; - - if (bpf_map_lookup_elem(&banned_ips, &key)) { - increment_ipv4_banned_stats(); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming permanently banned IPv4 %pI4", &iph->saddr); - return XDP_DROP; - } - - if (bpf_map_lookup_elem(&recently_banned_ips, &key)) { - increment_ipv4_recently_banned_stats(); - // Block UDP and ICMP from recently banned IPs, but allow DNS - if (iph->protocol == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && udph->dest == bpf_htons(53)) { - return XDP_PASS; // Allow DNS responses - } - // Block other UDP traffic - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); - return XDP_DROP; - } - if (iph->protocol == IPPROTO_ICMP) { - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - //bpf_printk("XDP: BLOCKED incoming ICMP from recently banned IPv4 %pI4, promoted to permanent ban", &iph->saddr); - return XDP_DROP; - } - // For TCP, promote to banned on FIN/RST - if (iph->protocol == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph && (tcph->fin || tcph->rst)) { - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips, &key); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - } - } - return XDP_PASS; - } - - // Perform TCP fingerprinting ONLY on SYN packets - if (iph->protocol == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph) { - // Only fingerprint SYN packets (not SYN-ACK) to capture MSS/WSCALE - if (tcph->syn && !tcph->ack) { - increment_tcp_syn_stats(); - - // Generate fingerprint to check if blocked - __u8 fingerprint[14] = {0}; - generate_tcp_fingerprint(tcph, data_end, iph->ttl, fingerprint); - - // Check if this TCP fingerprint is blocked - if (is_tcp_fingerprint_blocked(fingerprint)) { - increment_tcp_fingerprint_blocks_ipv4(); - increment_total_packets_dropped(); - increment_dropped_ipv4_address(iph->saddr); - return XDP_DROP; - } - // Record fingerprint for monitoring - struct fp_record_info fp_info = {0}; - fp_info.src_ip_v4 = iph->saddr; - fp_info.src_port = tcph->source; - fp_info.ttl = iph->ttl; - __builtin_memcpy(fp_info.fingerprint, fingerprint, 14); - record_tcp_fingerprint_v4(&fp_info, tcph, data_end); - } - } - } - - return XDP_PASS; - } - else if (h_proto == bpf_htons(ETH_P_IPV6)) { - struct ipv6hdr *ip6h = parse_and_advance(&cursor, data_end, sizeof(*ip6h)); - if (!ip6h) - return XDP_PASS; - - // Always allow DNS traffic (UDP port 53) to pass through - if (ip6h->nexthdr == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && (udph->dest == bpf_htons(53) || udph->source == bpf_htons(53))) { - return XDP_PASS; // Always allow DNS traffic - } - } - - // Check banned/recently banned maps by source IPv6 - struct lpm_key_v6 key6 = { - .prefixlen = 128, - }; - __u8 *src_addr = (__u8 *)&ip6h->saddr; - __builtin_memcpy(key6.addr, src_addr, 16); - - if (bpf_map_lookup_elem(&banned_ips_v6, &key6)) { - increment_ipv6_banned_stats(); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: BLOCKED incoming permanently banned IPv6"); - return XDP_DROP; - } - - if (bpf_map_lookup_elem(&recently_banned_ips_v6, &key6)) { - increment_ipv6_recently_banned_stats(); - // Block UDP and ICMP from recently banned IPv6 IPs, but allow DNS - if (ip6h->nexthdr == IPPROTO_UDP) { - struct udphdr *udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); - if (udph && udph->dest == bpf_htons(53)) { - return XDP_PASS; // Allow DNS responses - } - // Block other UDP traffic - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips_v6, &key6); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv6, promoted to permanent ban"); - return XDP_DROP; - } - if (ip6h->nexthdr == 58) { // 58 = IPPROTO_ICMPV6 - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips_v6, &key6); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: BLOCKED incoming ICMPv6 from recently banned IPv6, promoted to permanent ban"); - return XDP_DROP; - } - // For TCP, only promote to banned on FIN/RST - if (ip6h->nexthdr == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph) { - if (tcph->fin || tcph->rst) { - ip_flag_t one = 1; - bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); - bpf_map_delete_elem(&recently_banned_ips_v6, &key6); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - //bpf_printk("XDP: TCP FIN/RST from incoming recently banned IPv6, promoted to permanent ban"); - } - } - } - return XDP_PASS; // Allow if recently banned - } - - // Perform TCP fingerprinting on IPv6 TCP packets - if (ip6h->nexthdr == IPPROTO_TCP) { - struct tcphdr *tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); - if (tcph) { - // Perform TCP fingerprinting ONLY on SYN packets (not SYN-ACK) - // This ensures we capture the initial handshake with MSS/WSCALE - if (tcph->syn && !tcph->ack) { - // Skip IPv6 localhost traffic (::1) - simplified check - __u8 *src_addr = (__u8 *)&ip6h->saddr; - // Quick localhost check: first 8 bytes zero, bytes 8-14 zero, byte 15 is 1 - if (src_addr[0] == 0 && src_addr[7] == 0 && - src_addr[8] == 0 && src_addr[14] == 0 && src_addr[15] == 1) { - return XDP_PASS; - } - - // Extract TTL from IPv6 hop limit - __u16 ttl = ip6h->hop_limit; - - // Generate fingerprint to check if blocked - __u8 fingerprint[14] = {0}; - generate_tcp_fingerprint(tcph, data_end, ttl, fingerprint); - - // Check if this TCP fingerprint is blocked - if (is_tcp_fingerprint_blocked_v6(fingerprint)) { - increment_tcp_fingerprint_blocks_ipv6(); - increment_total_packets_dropped(); - increment_dropped_ipv6_address(ip6h->saddr); - return XDP_DROP; - } - // Record fingerprint for monitoring - struct fp_record_info fp_info = {0}; - __builtin_memcpy(fp_info.src_ip_v6, src_addr, 16); - fp_info.src_port = tcph->source; - fp_info.ttl = ttl; - __builtin_memcpy(fp_info.fingerprint, fingerprint, 14); - record_tcp_fingerprint_v6(&fp_info, tcph, data_end); - } - } - } - - return XDP_PASS; - } - - return XDP_PASS; - // return XDP_ABORTED; -} - -char _license[] SEC("license") = "GPL"; diff --git a/src/security/firewall/bpf/filter.h b/src/security/firewall/bpf/filter.h deleted file mode 100644 index 747c731..0000000 --- a/src/security/firewall/bpf/filter.h +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0 -#pragma once - -// Maximum entries for IP maps -#ifndef CITADEL_IP_MAP_MAX -#define CITADEL_IP_MAP_MAX 65536 -#endif - -// Map value is a simple flag (present = 1) -typedef __u8 ip_flag_t; - diff --git a/src/security/firewall/bpf/include/common.h b/src/security/firewall/bpf/include/common.h new file mode 100644 index 0000000..c836bfc --- /dev/null +++ b/src/security/firewall/bpf/include/common.h @@ -0,0 +1,32 @@ +#pragma once + +#include "vmlinux.h" + +// #include +// #include +// #include + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +#include +#include + +#define NF_DROP 0 +#define NF_ACCEPT 1 +#define ETH_P_IP 0x0800 +#define ETH_P_IPV6 0x86DD +#define IP_MF 0x2000 +#define IP_OFFSET 0x1FFF +#define NEXTHDR_FRAGMENT 44 + +#define CITADEL_IP_MAP_MAX 65536 + +typedef __u8 ip_flag_t; + +typedef __u8 ipv6_addr[16]; \ No newline at end of file diff --git a/src/security/firewall/bpf/lib/firewall.h b/src/security/firewall/bpf/lib/firewall.h new file mode 100644 index 0000000..2bfddc7 --- /dev/null +++ b/src/security/firewall/bpf/lib/firewall.h @@ -0,0 +1,181 @@ +#pragma once + +#include "common.h" + +#include "../xdp_maps.h" +#include "vmlinux.h" + +struct lpm_key { + __u32 prefixlen; + __be32 addr; +}; + +struct lpm_key_v6 { + __u32 prefixlen; + __u8 addr[16]; +}; + +struct src_port_key_v4 { + __be32 addr; + __be16 port; +}; + +struct src_port_key_v6 { + __u8 addr[16]; + __be16 port; +}; + +/* + * Helper functions for incrementing statistics counters + */ +static __always_inline void increment_ipv4_banned_stats(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv4_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_ipv4_recently_banned_stats(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv4_recently_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_ipv6_banned_stats(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv6_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_ipv6_recently_banned_stats(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&ipv6_recently_banned_stats, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_total_packets_processed(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&total_packets_processed, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_total_packets_dropped(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&total_packets_dropped, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_dropped_ipv4_address(__be32 ip_addr) { + __u64 *value = bpf_map_lookup_elem(&dropped_ipv4_addresses, &ip_addr); + if (value) { + __sync_fetch_and_add(value, 1); + } else { + // First time dropping this IP, initialize counter + __u64 initial_count = 1; + bpf_map_update_elem(&dropped_ipv4_addresses, &ip_addr, &initial_count, + BPF_ANY); + } +} + +static __always_inline void +increment_dropped_ipv6_address(struct in6_addr ip_addr) { + __u8 *addr_bytes = (__u8 *)&ip_addr; + __u64 *value = bpf_map_lookup_elem(&dropped_ipv6_addresses, addr_bytes); + if (value) { + __sync_fetch_and_add(value, 1); + } else { + // First time dropping this IP, initialize counter + __u64 initial_count = 1; + bpf_map_update_elem(&dropped_ipv6_addresses, addr_bytes, &initial_count, + BPF_ANY); + } +} + +static __noinline int xdp_portban(struct iphdr *iph, struct tcphdr *tcph, + struct udphdr *udph) { + + if (!iph) { + return XDP_PASS; + } + + __be16 src_port = 0; + __be16 dst_port = 0; + + if (tcph) { + src_port = tcph->source; + dst_port = tcph->dest; + } else if (udph) { + src_port = udph->source; + dst_port = udph->dest; + } else { + return XDP_PASS; + } + + struct src_port_key_v4 inbound_key = { + .addr = iph->saddr, + .port = src_port, + }; + + if (bpf_map_lookup_elem(&banned_inbound_ipv4_address_ports, &inbound_key)) { + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + return XDP_DROP; + } + + struct src_port_key_v4 outbound_key = { + .addr = iph->daddr, + .port = dst_port, + }; + + if (bpf_map_lookup_elem(&banned_outbound_ipv4_address_ports, &outbound_key)) { + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->daddr); + return XDP_DROP; + } + + return XDP_PASS; +} + +static __noinline int xdp_firewall(struct iphdr *iph, struct ipv6hdr *ip6h) { + + if (iph) { + struct lpm_key key = { + .prefixlen = 32, + .addr = iph->saddr, + }; + + if (bpf_map_lookup_elem(&banned_ips, &key)) { + increment_ipv4_banned_stats(); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + // bpf_printk("XDP: BLOCKED incoming permanently banned IPv4 %pI4", + // &iph->saddr); + return XDP_DROP; + } + } else if (ip6h) { + struct lpm_key_v6 key6 = {.prefixlen = 128}; + + __builtin_memcpy(&key6.addr, &ip6h->saddr, sizeof(ip6h->saddr)); + + if (bpf_map_lookup_elem(&banned_ips_v6, &key6)) { + increment_ipv6_banned_stats(); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + // bpf_printk("XDP: BLOCKED incoming permanently banned IPv6"); + return XDP_DROP; + } + } + + return XDP_PASS; +} \ No newline at end of file diff --git a/src/security/firewall/bpf/lib/helper.h b/src/security/firewall/bpf/lib/helper.h new file mode 100644 index 0000000..603a1bf --- /dev/null +++ b/src/security/firewall/bpf/lib/helper.h @@ -0,0 +1,37 @@ +#pragma once + +#include "common.h" +#include "vmlinux.h" +#include + +static __always_inline bool is_frag_v4(const struct iphdr *iph) { + return (iph->frag_off & bpf_htons(IP_MF | IP_OFFSET)) != 0; +} + +static __always_inline bool is_frag_v6(const struct ipv6hdr *ip6h) { + return ip6h->nexthdr == NEXTHDR_FRAGMENT; +} + +/* + * Helper for bounds checking and advancing a cursor. + * + * @cursor: pointer to current parsing position + * @end: pointer to end of packet data + * @len: length of the struct to read + * + * Returns a pointer to the struct if it's within bounds, + * and advances the cursor. Returns NULL otherwise. + */ +static __always_inline void *parse_and_advance(void **cursor, void *end, + __u32 len) { + void *current = *cursor; + if (current + len > end) + return NULL; + *cursor = current + len; + return current; +} + +static __always_inline void copy_ipv6_addr_as_array(ipv6_addr *dest, + struct in6_addr *src) { + __builtin_memcpy(dest, src, sizeof(*src)); +} \ No newline at end of file diff --git a/src/security/firewall/bpf/lib/ratelimit.h b/src/security/firewall/bpf/lib/ratelimit.h new file mode 100644 index 0000000..ceca571 --- /dev/null +++ b/src/security/firewall/bpf/lib/ratelimit.h @@ -0,0 +1,135 @@ +#pragma once + +#include "common.h" + +#include "../xdp_maps.h" +#include "vmlinux.h" + +struct ratelimiter_config_t { + __u64 TOKENS_PER_REQUEST; // tokens consumed per request + __u64 REFILL_RATE; // tokens added per sec + __u64 MAX_BUCKET_CAPACITY; // refill_rate * (max_bucket_capacity / + // refill_rate) = allow x request burst + __u8 ENABLED; +}; + +// set this before loading the script +volatile struct ratelimiter_config_t ratelimiter_config = {1, 1000, 3000, + false}; + +static __always_inline bool is_ipv4_whitelisted(__be32 *addr) { + return !!bpf_map_lookup_elem(&ipv4_ratelimit_whitelist, addr); +} + +static __always_inline bool is_ipv6_whitelisted(ipv6_addr *addr) { + return !!bpf_map_lookup_elem(&ipv6_ratelimit_whitelist, addr); +} + +static __always_inline void refill_tokens(struct ratelimit_bucket_value *rl_val, + __u64 now) { + + __u64 elapsed_ns = now - rl_val->last_topup; + + // Calculate tokens to add: (elapsed_seconds * REFILL_RATE) + __u64 tokens_to_add = + (elapsed_ns * ratelimiter_config.REFILL_RATE) / 1000000000ULL; + + if (tokens_to_add > 0) { + rl_val->num_of_tokens += tokens_to_add; + + if (rl_val->num_of_tokens > ratelimiter_config.MAX_BUCKET_CAPACITY) { + rl_val->num_of_tokens = ratelimiter_config.MAX_BUCKET_CAPACITY; + } + + rl_val->last_topup = now; + } +} + +static __noinline __u8 ipv4_syn_ratelimit(__be32 *addr, struct tcphdr *tcph) { + if (is_ipv4_whitelisted(addr)) { + return XDP_PASS; + } + + __u64 now = bpf_ktime_get_ns(); + + struct ratelimit_bucket_value *bucket = + bpf_map_lookup_elem(&ipv4_syn_bucket_store, addr); + + // bpf_printk("Bucket status -> tokens: %llu\n", bucket->num_of_tokens); + + if (!bucket) { + struct ratelimit_bucket_value new_bucket = {}; + new_bucket.last_topup = now; + new_bucket.num_of_tokens = ratelimiter_config.MAX_BUCKET_CAPACITY; + bpf_map_update_elem(&ipv4_syn_bucket_store, addr, &new_bucket, BPF_ANY); + // bpf_printk("Bucket created for addr: %u, topup: %ld\n", *addr, now); + return XDP_PASS; + } + + refill_tokens(bucket, now); + + if (bucket->num_of_tokens >= ratelimiter_config.TOKENS_PER_REQUEST) { + bucket->num_of_tokens -= ratelimiter_config.TOKENS_PER_REQUEST; + // bpf_printk("Packet passed for addr: %d, num of tokens: %llu\n", *addr, + // bucket->num_of_tokens); + return XDP_PASS; + } + + // bpf_printk("Packet dropped for addr: %u\n", *addr); + return XDP_DROP; +} + +static __noinline __u8 ipv6_syn_ratelimit(ipv6_addr *addr, + struct tcphdr *tcph) { + if (is_ipv6_whitelisted(addr)) { + return XDP_PASS; + } + + __u64 now = bpf_ktime_get_ns(); + struct ratelimit_bucket_value *bucket = + bpf_map_lookup_elem(&ipv6_syn_bucket_store, addr); + + if (!bucket) { + struct ratelimit_bucket_value new_bucket = {}; + new_bucket.last_topup = now; + new_bucket.num_of_tokens = ratelimiter_config.MAX_BUCKET_CAPACITY; + bpf_map_update_elem(&ipv6_syn_bucket_store, addr, &new_bucket, BPF_ANY); + + return XDP_PASS; + } + + refill_tokens(bucket, now); + + if (bucket->num_of_tokens >= ratelimiter_config.TOKENS_PER_REQUEST) { + bucket->num_of_tokens -= ratelimiter_config.TOKENS_PER_REQUEST; + return XDP_PASS; + } + + return XDP_DROP; +} + +int __noinline xdp_ratelimit(struct iphdr *iph, struct ipv6hdr *ip6h, + struct tcphdr *tcph) { + if (!ratelimiter_config.ENABLED) { + return XDP_PASS; + } + + if (tcph) { + if (!tcph->syn || tcph->ack) { + return XDP_PASS; + } + if (iph) { + if (ipv4_syn_ratelimit(&iph->saddr, tcph) == XDP_DROP) { + return XDP_DROP; + } + } else if (ip6h) { + ipv6_addr *ipv6_saddr_ptr = (ipv6_addr *)&ip6h->saddr; + + if (ipv6_syn_ratelimit(ipv6_saddr_ptr, tcph) == XDP_DROP) { + return XDP_DROP; + } + } + } + + return XDP_PASS; +} diff --git a/src/security/firewall/bpf/lib/tcp_fingerprinting.h b/src/security/firewall/bpf/lib/tcp_fingerprinting.h new file mode 100644 index 0000000..25108bb --- /dev/null +++ b/src/security/firewall/bpf/lib/tcp_fingerprinting.h @@ -0,0 +1,359 @@ +#pragma once + +#include "common.h" + +#include "../xdp_maps.h" +#include "firewall.h" +#include "helper.h" + +// TCP fingerprinting constants + +#define TCP_FP_KEY_SIZE 20 // 4 bytes IP + 2 bytes port + 14 bytes fingerprint +#define TCP_FP_MAX_OPTIONS 10 +#define TCP_FP_MAX_OPTION_LEN 40 + +// TCP fingerprinting structures +struct tcp_fingerprint_key { + __be32 src_ip; // Source IP address (IPv4) + __be16 src_port; // Source port + __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) +}; + +struct tcp_fingerprint_key_v6 { + __u8 src_ip[16]; // Source IP address (IPv6) + __be16 src_port; // Source port + __u8 fingerprint[14]; // TCP fingerprint string (null-terminated) +}; + +struct tcp_fingerprint_data { + __u64 first_seen; // Timestamp of first packet + __u64 last_seen; // Timestamp of last packet + __u32 packet_count; // Number of packets seen + __u16 ttl; // Initial TTL + __u16 mss; // Maximum Segment Size + __u16 window_size; // TCP window size + __u8 window_scale; // Window scaling factor + __u8 options_len; // Length of TCP options + __u8 options[TCP_FP_MAX_OPTION_LEN]; // TCP options data +}; + +struct tcp_syn_stats { + __u64 total_syns; + __u64 unique_fingerprints; + __u64 last_reset; +}; + +/* + * TCP fingerprinting helper functions + */ +static __always_inline void increment_tcp_syn_stats(void) { + __u32 key = 0; + struct tcp_syn_stats *stats = bpf_map_lookup_elem(&tcp_syn_stats, &key); + if (stats) { + __sync_fetch_and_add(&stats->total_syns, 1); + } else { + struct tcp_syn_stats new_stats = {0}; + new_stats.total_syns = 1; + bpf_map_update_elem(&tcp_syn_stats, &key, &new_stats, BPF_ANY); + } +} + +static __always_inline void increment_unique_fingerprints(void) { + __u32 key = 0; + struct tcp_syn_stats *stats = bpf_map_lookup_elem(&tcp_syn_stats, &key); + if (stats) { + __sync_fetch_and_add(&stats->unique_fingerprints, 1); + } +} + +static __always_inline void increment_tcp_fingerprint_blocks_ipv4(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv4, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline void increment_tcp_fingerprint_blocks_ipv6(void) { + __u32 key = 0; + __u64 *value = bpf_map_lookup_elem(&tcp_fingerprint_blocks_ipv6, &key); + if (value) { + __sync_fetch_and_add(value, 1); + } +} + +static __always_inline int parse_tcp_mss_wscale(struct tcphdr *tcp, + void *data_end, __u16 *mss_out, + __u8 *wscale_out) { + if ((void *)tcp + sizeof(struct tcphdr) > data_end) { + return 0; + } + + __u8 *ptr = (__u8 *)tcp + sizeof(struct tcphdr); + __u32 options_len = (tcp->doff * 4) - sizeof(struct tcphdr); + + // Guard against invalid doff + if (options_len > 40) { // Max TCP options length + options_len = 40; + } + + __u8 *end = ptr + options_len; + + // Ensure we don't exceed packet bounds + if (end > (__u8 *)data_end) { + end = (__u8 *)data_end; + } + + // Safety check + if (ptr >= end) { + return 0; + } + +// Parse options - limit to 20 iterations to handle NOPs +#pragma unroll + for (int i = 0; i < 20; i++) { + if (ptr >= end || ptr >= (__u8 *)data_end) + break; + if (ptr + 1 > (__u8 *)data_end) + break; + + __u8 kind = *ptr; + if (kind == 0) + break; // End of options + + if (kind == 1) { + // NOP option + ptr++; + continue; + } + + // Check bounds for option length + if (ptr + 2 > (__u8 *)data_end) + break; + __u8 len = *(ptr + 1); + if (len < 2 || ptr + len > (__u8 *)data_end) + break; + + // MSS option (kind=2, len=4) + if (kind == 2 && len == 4 && ptr + 4 <= (__u8 *)data_end) { + *mss_out = (*(ptr + 2) << 8) | *(ptr + 3); + } + // Window scale option (kind=3, len=3) + else if (kind == 3 && len == 3 && ptr + 3 <= (__u8 *)data_end) { + *wscale_out = *(ptr + 2); + } + + ptr += len; + } + + return 0; +} + +static __always_inline void generate_tcp_fingerprint(struct tcphdr *tcp, + void *data_end, __u16 ttl, + __u8 *fingerprint) { + // Generate JA4T-style fingerprint: ttl:mss:window:scale + __u16 mss = 0; + __u8 window_scale = 0; + + if ((void *)tcp + sizeof(struct tcphdr) > data_end) { + return; + } + + // Parse TCP options to extract MSS and window scaling + parse_tcp_mss_wscale(tcp, data_end, &mss, &window_scale); + + // Generate fingerprint string manually (BPF doesn't support complex + // formatting) + __u16 window = bpf_ntohs(tcp->window); + + // Format: "ttl:mss:window:scale" (max 14 chars) + fingerprint[0] = '0' + (ttl / 100); + fingerprint[1] = '0' + ((ttl / 10) % 10); + fingerprint[2] = '0' + (ttl % 10); + fingerprint[3] = ':'; + fingerprint[4] = '0' + (mss / 1000); + fingerprint[5] = '0' + ((mss / 100) % 10); + fingerprint[6] = '0' + ((mss / 10) % 10); + fingerprint[7] = '0' + (mss % 10); + fingerprint[8] = ':'; + fingerprint[9] = '0' + (window / 10000); + fingerprint[10] = '0' + ((window / 1000) % 10); + fingerprint[11] = '0' + ((window / 100) % 10); + fingerprint[12] = '0' + ((window / 10) % 10); + fingerprint[13] = '0' + (window % 10); + // Note: window_scale is not included due to space constraints +} + +/* + * Check if a TCP fingerprint is blocked (IPv4) + * Returns true if the fingerprint should be blocked + */ +static __always_inline bool is_tcp_fingerprint_blocked(__u8 *fingerprint) { + __u8 *blocked = bpf_map_lookup_elem(&blocked_tcp_fingerprints, fingerprint); + return (blocked != NULL && *blocked == 1); +} + +/* + * Check if a TCP fingerprint is blocked (IPv6) + * Returns true if the fingerprint should be blocked + */ +static __always_inline bool is_tcp_fingerprint_blocked_v6(__u8 *fingerprint) { + __u8 *blocked = + bpf_map_lookup_elem(&blocked_tcp_fingerprints_v6, fingerprint); + return (blocked != NULL && *blocked == 1); +} + +static __always_inline void record_tcp_fingerprint(__be32 src_ip, + __be16 src_port, + struct tcphdr *tcp, + void *data_end, __u16 ttl) { + // Skip localhost traffic to reduce noise + // Check for 127.0.0.0/8 range (127.0.0.1 to 127.255.255.255) + if ((src_ip & bpf_htonl(0xff000000)) == bpf_htonl(0x7f000000)) { + return; + } + + struct tcp_fingerprint_key key = {0}; + struct tcp_fingerprint_data data = {0}; + __u64 timestamp = bpf_ktime_get_ns(); + + key.src_ip = src_ip; + key.src_port = src_port; + + // Generate fingerprint + generate_tcp_fingerprint(tcp, data_end, ttl, key.fingerprint); + + // Check if fingerprint already exists + struct tcp_fingerprint_data *existing = + bpf_map_lookup_elem(&tcp_fingerprints, &key); + if (existing) { + // Update existing entry - must copy to local variable first + data.first_seen = existing->first_seen; + data.last_seen = timestamp; + data.packet_count = existing->packet_count + 1; + data.ttl = existing->ttl; + data.mss = existing->mss; + data.window_size = existing->window_size; + data.window_scale = existing->window_scale; + data.options_len = existing->options_len; + + // Copy options array + __builtin_memcpy(data.options, existing->options, TCP_FP_MAX_OPTION_LEN); + + bpf_map_update_elem(&tcp_fingerprints, &key, &data, BPF_ANY); + } else { + // Create new entry + data.first_seen = timestamp; + data.last_seen = timestamp; + data.packet_count = 1; + data.ttl = ttl; + data.window_size = bpf_ntohs(tcp->window); + + // Extract MSS and window scale from options + parse_tcp_mss_wscale(tcp, data_end, &data.mss, &data.window_scale); + + bpf_map_update_elem(&tcp_fingerprints, &key, &data, BPF_ANY); + increment_unique_fingerprints(); + + // Log new TCP fingerprint + // bpf_printk("TCP_FP: New fingerprint from %pI4:%d - TTL:%d MSS:%d WS:%d + // Window:%d", + // &src_ip, bpf_ntohs(src_port), ttl, data.mss, data.window_scale, + // data.window_size); + } +} + +static __noinline int xdp_tcp_fingerprinting(struct xdp_md *ctx, + struct iphdr *iph, + struct ipv6hdr *ip6h, + struct tcphdr *tcph) { + if (tcph) { + if (iph) { + if (tcph->syn && !tcph->ack) { + increment_tcp_syn_stats(); + + __u8 fingerprint[14] = {0}; + generate_tcp_fingerprint(tcph, (void *)(long)ctx->data_end, iph->ttl, + fingerprint); + + if (is_tcp_fingerprint_blocked(fingerprint)) { + increment_tcp_fingerprint_blocks_ipv4(); + increment_total_packets_dropped(); + increment_dropped_ipv4_address(iph->saddr); + // bpf_printk("XDP: BLOCKED TCP fingerprint from IPv4 %pI4:%d - + // FP:%s", + // &iph->saddr, bpf_ntohs(tcph->source), fingerprint); + return XDP_DROP; + } + + record_tcp_fingerprint(iph->saddr, tcph->source, tcph, + (void *)(long)ctx->data_end, iph->ttl); + + // bpf_printk("TCP_FP: New IPv4 fingerprint from %pI4:%d - TTL:%d", + // &iph->saddr, bpf_ntohs(tcph->source), iph->ttl); + } + } else if (ip6h) { + __u8 fingerprint[14] = {0}; + generate_tcp_fingerprint(tcph, (void *)(long)ctx->data_end, + ip6h->hop_limit, fingerprint); + + if (is_tcp_fingerprint_blocked_v6(fingerprint)) { + increment_tcp_fingerprint_blocks_ipv6(); + increment_total_packets_dropped(); + increment_dropped_ipv6_address(ip6h->saddr); + // bpf_printk("XDP: BLOCKED TCP fingerprint from IPv6 %pI6:%d - + // FP:%s", + // &ip6h->saddr, bpf_ntohs(tcph->source), fingerprint); + return XDP_DROP; + } + + struct tcp_fingerprint_key_v6 key = {0}; + struct tcp_fingerprint_data data = {0}; + __u64 timestamp = bpf_ktime_get_ns(); + + copy_ipv6_addr_as_array(&key.src_ip, &ip6h->saddr); + key.src_port = tcph->source; + + __builtin_memcpy(key.fingerprint, fingerprint, sizeof(fingerprint)); + + struct tcp_fingerprint_data *existing = + bpf_map_lookup_elem(&tcp_fingerprints_v6, &key); + if (existing) { + data.first_seen = existing->first_seen; + data.last_seen = timestamp; + data.packet_count = existing->packet_count + 1; + data.ttl = existing->ttl; + data.mss = existing->mss; + data.window_size = existing->window_size; + data.window_scale = existing->window_scale; + data.options_len = existing->options_len; + + __builtin_memcpy(data.options, existing->options, + TCP_FP_MAX_OPTION_LEN); + bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_ANY); + + } else { + data.first_seen = timestamp; + data.last_seen = timestamp; + data.packet_count = 1; + data.ttl = ip6h->hop_limit; + data.window_size = bpf_ntohs(tcph->window); + + parse_tcp_mss_wscale(tcph, (void *)(long)ctx->data_end, &data.mss, + &data.window_scale); + + bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_ANY); + increment_unique_fingerprints(); + + // Log new IPv6 TCP fingerprint + // bpf_printk("TCP_FP: New IPv6 fingerprint from %pI6:%d - TTL:%d MSS:%d + // " + // "WS:%d Window:%d", + // &ip6h->saddr, bpf_ntohs(tcph->source), data.ttl, data.mss, + // data.window_scale, data.window_size); + } + } + } + return XDP_PASS; +} \ No newline at end of file diff --git a/src/security/firewall/bpf/xdp.bpf.c b/src/security/firewall/bpf/xdp.bpf.c new file mode 100644 index 0000000..cacdec0 --- /dev/null +++ b/src/security/firewall/bpf/xdp.bpf.c @@ -0,0 +1,533 @@ +#include "common.h" + +#include "lib/firewall.h" +#include "lib/helper.h" +#include "lib/tcp_fingerprinting.h" +#include "ratelimit.h" +#include "vmlinux.h" +#include "xdp_maps.h" + +SEC("xdp") +int xdp_pipeline(struct xdp_md *ctx) { + void *data_end = (void *)(long)ctx->data_end; + void *cursor = (void *)(long)ctx->data; + + // Debug: Count all packets + __u32 zero = 0; + __u32 *packet_count = bpf_map_lookup_elem(&total_packets_processed, &zero); + if (packet_count) { + __sync_fetch_and_add(packet_count, 1); + } + + struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth)); + if (!eth) { + return XDP_PASS; + } + + struct iphdr *iph = NULL; + struct ipv6hdr *ip6h = NULL; + if (eth->h_proto == bpf_htons(ETH_P_IP)) { + iph = parse_and_advance(&cursor, data_end, sizeof(*iph)); + if (!iph) { + return XDP_PASS; + } + + } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { + ip6h = parse_and_advance(&cursor, data_end, sizeof(*ip6h)); + if (!ip6h) { + return XDP_PASS; + } + } + + struct tcphdr *tcph = NULL; + struct udphdr *udph = NULL; + if ((ip6h && ip6h->nexthdr == IPPROTO_UDP) || + (iph && iph->protocol == IPPROTO_UDP)) { + udph = parse_and_advance(&cursor, data_end, sizeof(*udph)); + if (!udph) { + return XDP_PASS; + } + } else if ((ip6h && ip6h->nexthdr == IPPROTO_TCP) || + (iph && iph->protocol == IPPROTO_TCP)) { + tcph = parse_and_advance(&cursor, data_end, sizeof(*tcph)); + if (!tcph) { + return XDP_PASS; + } + } + + if (xdp_firewall(iph, ip6h) == XDP_DROP) + return XDP_DROP; + + if (xdp_portban(iph, tcph, udph) == XDP_DROP) + return XDP_DROP; + + if (xdp_tcp_fingerprinting(ctx, iph, ip6h, tcph) == XDP_DROP) + return XDP_DROP; + + if (xdp_ratelimit(iph, ip6h, tcph) == XDP_DROP) + return XDP_DROP; + + increment_total_packets_processed(); + + return XDP_PASS; +} + +// SEC("xdp") +// int arxignis_xdp_filter(struct xdp_md *ctx) { +// // This filter is designed to only block incoming traffic +// // It should be attached only to ingress hooks, not egress +// // The filtering logic below blocks packets based on source IP addresses +// // +// // IP Version Support: +// // - Supports IPv4-only, IPv6-only, and hybrid (both) modes +// // - Note: XDP requires IPv6 to be enabled at kernel level for attachment, +// // even when processing only IPv4 packets. This is a kernel limitation. +// // - The BPF program processes both IPv4 and IPv6 packets based on the +// // ethernet protocol type (ETH_P_IP for IPv4, ETH_P_IPV6 for IPv6) + +// void *data_end = (void *)(long)ctx->data_end; +// void *cursor = (void *)(long)ctx->data; + +// // Debug: Count all packets +// __u32 zero = 0; +// __u32 *packet_count = bpf_map_lookup_elem(&total_packets_processed, &zero); +// if (packet_count) { +// __sync_fetch_and_add(packet_count, 1); +// } + +// struct ethhdr *eth = parse_and_advance(&cursor, data_end, sizeof(*eth)); +// if (!eth) +// return XDP_PASS; + +// __u16 h_proto = eth->h_proto; + +// // Increment total packets processed counter +// increment_total_packets_processed(); + +// if (h_proto == bpf_htons(ETH_P_IP)) { +// struct iphdr *iph = parse_and_advance(&cursor, data_end, sizeof(*iph)); +// if (!iph) +// return XDP_PASS; + +// struct lpm_key key = { +// .prefixlen = 32, +// .addr = iph->saddr, +// }; + +// if (bpf_map_lookup_elem(&banned_ips, &key)) { +// increment_ipv4_banned_stats(); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// // bpf_printk("XDP: BLOCKED incoming permanently banned IPv4 %pI4", +// // &iph->saddr); +// return XDP_DROP; +// } + +// if (bpf_map_lookup_elem(&recently_banned_ips, &key)) { +// increment_ipv4_recently_banned_stats(); +// // Block UDP and ICMP from recently banned IPs, but allow DNS +// if (iph->protocol == IPPROTO_UDP) { +// struct udphdr *udph = +// parse_and_advance(&cursor, data_end, sizeof(*udph)); +// if (udph && udph->dest == bpf_htons(53)) { +// return XDP_PASS; // Allow DNS responses +// } +// // Block other UDP traffic +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips, &key); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// // bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv4 +// %pI4, +// // promoted to permanent ban", &iph->saddr); +// return XDP_DROP; +// } +// if (iph->protocol == IPPROTO_ICMP) { +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips, &key); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// // bpf_printk("XDP: BLOCKED incoming ICMP from recently banned IPv4 +// // %pI4, promoted to permanent ban", &iph->saddr); +// return XDP_DROP; +// } +// // For TCP, only promote to banned on FIN/RST +// if (iph->protocol == IPPROTO_TCP) { +// bpf_printk("tcp asd"); +// struct tcphdr *tcph = +// parse_and_advance(&cursor, data_end, sizeof(*tcph)); +// if (tcph) { +// // Perform TCP fingerprinting ONLY on SYN packets (not SYN-ACK) +// // This ensures we capture the initial handshake with MSS/WSCALE +// if (tcph->syn && !tcph->ack) { +// increment_tcp_syn_stats(); + +// if (ipv4_syn_ratelimit(iph->saddr) == XDP_DROP) { +// return XDP_DROP; +// } + +// // Generate fingerprint to check if blocked +// __u8 fingerprint[14] = {0}; +// generate_tcp_fingerprint(tcph, data_end, iph->ttl, fingerprint); + +// // Check if this TCP fingerprint is blocked +// if (is_tcp_fingerprint_blocked(fingerprint)) { +// increment_tcp_fingerprint_blocks_ipv4(); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// return XDP_DROP; +// } + +// record_tcp_fingerprint(iph->saddr, tcph->source, tcph, data_end, +// iph->ttl); +// } + +// if (tcph->fin || tcph->rst) { +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips, &key, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips, &key); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// // bpf_printk("XDP: TCP FIN/RST from incoming recently banned +// IPv4 +// // %pI4, promoted to permanent ban", &iph->saddr); +// } +// } +// } +// return XDP_PASS; +// } + +// // Perform TCP fingerprinting ONLY on SYN packets +// if (iph->protocol == IPPROTO_TCP) { +// struct tcphdr *tcph = parse_and_advance(&cursor, data_end, +// sizeof(*tcph)); if (tcph) { +// // Only fingerprint SYN packets (not SYN-ACK) to capture MSS/WSCALE +// if (tcph->syn && !tcph->ack) { +// increment_tcp_syn_stats(); + +// // Generate fingerprint to check if blocked +// __u8 fingerprint[14] = {0}; +// generate_tcp_fingerprint(tcph, data_end, iph->ttl, fingerprint); + +// // Check if this TCP fingerprint is blocked +// if (is_tcp_fingerprint_blocked(fingerprint)) { +// increment_tcp_fingerprint_blocks_ipv4(); +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// // bpf_printk("XDP: BLOCKED TCP fingerprint from IPv4 %pI4:%d - +// // FP:%s", +// // &iph->saddr, bpf_ntohs(tcph->source), fingerprint); +// return XDP_DROP; +// } + +// // Record fingerprint for monitoring +// record_tcp_fingerprint(iph->saddr, tcph->source, tcph, data_end, +// iph->ttl); +// } +// } +// } + +// // Check IPv4 port bans +// if (iph->protocol == IPPROTO_TCP || iph->protocol == IPPROTO_UDP) { +// void *port_cursor = cursor; +// __be16 src_port = 0; +// __be16 dst_port = 0; + +// if (iph->protocol == IPPROTO_TCP) { +// struct tcphdr *tcph_tmp = +// parse_and_advance(&port_cursor, data_end, sizeof(*tcph_tmp)); +// if (!tcph_tmp) +// return XDP_PASS; +// src_port = tcph_tmp->source; +// dst_port = tcph_tmp->dest; +// } else { +// struct udphdr *udph_tmp = +// parse_and_advance(&port_cursor, data_end, sizeof(*udph_tmp)); +// if (!udph_tmp) +// return XDP_PASS; +// src_port = udph_tmp->source; +// dst_port = udph_tmp->dest; +// } + +// struct src_port_key_v4 inbound_key = { +// .addr = iph->saddr, +// .port = src_port, +// }; + +// if (bpf_map_lookup_elem(&banned_inbound_ipv4_address_ports, +// &inbound_key)) { +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->saddr); +// return XDP_DROP; +// } + +// struct src_port_key_v4 outbound_key = { +// .addr = iph->daddr, +// .port = dst_port, +// }; + +// if (bpf_map_lookup_elem(&banned_outbound_ipv4_address_ports, +// &outbound_key)) { +// increment_total_packets_dropped(); +// increment_dropped_ipv4_address(iph->daddr); +// return XDP_DROP; +// } +// } + +// return XDP_PASS; +// } else if (h_proto == bpf_htons(ETH_P_IPV6)) { +// struct ipv6hdr *ip6h = parse_and_advance(&cursor, data_end, +// sizeof(*ip6h)); if (!ip6h) +// return XDP_PASS; + +// // Always allow DNS traffic (UDP port 53) to pass through +// if (ip6h->nexthdr == IPPROTO_UDP) { +// struct udphdr *udph = parse_and_advance(&cursor, data_end, +// sizeof(*udph)); if (udph && +// (udph->dest == bpf_htons(53) || udph->source == bpf_htons(53))) { +// return XDP_PASS; // Always allow DNS traffic +// } +// } + +// // Check banned/recently banned maps by source IPv6 +// struct lpm_key_v6 key6 = { +// .prefixlen = 128, +// }; +// // Manual copy for BPF compatibility +// __u8 *src_addr = (__u8 *)&ip6h->saddr; +// #pragma unroll +// for (int i = 0; i < 16; i++) { +// key6.addr[i] = src_addr[i]; +// } + +// if (bpf_map_lookup_elem(&banned_ips_v6, &key6)) { +// increment_ipv6_banned_stats(); +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// // bpf_printk("XDP: BLOCKED incoming permanently banned IPv6"); +// return XDP_DROP; +// } + +// if (bpf_map_lookup_elem(&recently_banned_ips_v6, &key6)) { +// increment_ipv6_recently_banned_stats(); +// // Block UDP and ICMP from recently banned IPv6 IPs, but allow DNS +// if (ip6h->nexthdr == IPPROTO_UDP) { +// struct udphdr *udph = +// parse_and_advance(&cursor, data_end, sizeof(*udph)); +// if (udph && udph->dest == bpf_htons(53)) { +// return XDP_PASS; // Allow DNS responses +// } +// // Block other UDP traffic +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips_v6, &key6); +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// // bpf_printk("XDP: BLOCKED incoming UDP from recently banned IPv6, +// // promoted to permanent ban"); +// return XDP_DROP; +// } +// if (ip6h->nexthdr == 58) { // 58 = IPPROTO_ICMPV6 +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips_v6, &key6); +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// // bpf_printk("XDP: BLOCKED incoming ICMPv6 from recently banned +// IPv6, +// // promoted to permanent ban"); +// return XDP_DROP; +// } +// // For TCP, only promote to banned on FIN/RST +// if (ip6h->nexthdr == IPPROTO_TCP) { +// struct tcphdr *tcph = +// parse_and_advance(&cursor, data_end, sizeof(*tcph)); +// if (tcph) { +// if (tcph->fin || tcph->rst) { +// ip_flag_t one = 1; +// bpf_map_update_elem(&banned_ips_v6, &key6, &one, BPF_ANY); +// bpf_map_delete_elem(&recently_banned_ips_v6, &key6); +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// // bpf_printk("XDP: TCP FIN/RST from incoming recently banned +// IPv6, +// // promoted to permanent ban"); +// } +// } +// } +// return XDP_PASS; // Allow if recently banned +// } + +// // Perform TCP fingerprinting on IPv6 TCP packets +// if (ip6h->nexthdr == IPPROTO_TCP) { +// struct tcphdr *tcph = parse_and_advance(&cursor, data_end, +// sizeof(*tcph)); if (tcph) { +// // Perform TCP fingerprinting ONLY on SYN packets (not SYN-ACK) +// // This ensures we capture the initial handshake with MSS/WSCALE +// if (tcph->syn && !tcph->ack) { +// // Skip IPv6 localhost traffic to reduce noise +// // Check for ::1 (IPv6 localhost) - manual comparison +// __u8 *src_addr = (__u8 *)&ip6h->saddr; +// bool is_localhost = true; + +// // Check first 15 bytes are zero +// #pragma unroll +// for (int i = 0; i < 15; i++) { +// if (src_addr[i] != 0) { +// is_localhost = false; +// break; +// } +// } +// // Check last byte is 1 +// if (is_localhost && src_addr[15] == 1) { +// return XDP_PASS; +// } + +// if (ipv6_syn_ratelimit(src_addr) == XDP_DROP) { +// return XDP_DROP; +// } + +// // Extract TTL from IPv6 hop limit +// __u16 ttl = ip6h->hop_limit; + +// // Generate fingerprint to check if blocked +// __u8 fingerprint[14] = {0}; +// generate_tcp_fingerprint(tcph, data_end, ttl, fingerprint); + +// // Check if this TCP fingerprint is blocked +// if (is_tcp_fingerprint_blocked_v6(fingerprint)) { +// increment_tcp_fingerprint_blocks_ipv6(); +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// // bpf_printk("XDP: BLOCKED TCP fingerprint from IPv6 %pI6:%d - +// // FP:%s", +// // &ip6h->saddr, bpf_ntohs(tcph->source), +// fingerprint); return XDP_DROP; +// } + +// // Create IPv6 fingerprint key with full 128-bit address +// struct tcp_fingerprint_key_v6 key = {0}; +// struct tcp_fingerprint_data data = {0}; +// __u64 timestamp = bpf_ktime_get_ns(); + +// // Copy full IPv6 address (16 bytes) - manual copy for BPF +// #pragma unroll +// for (int i = 0; i < 16; i++) { +// key.src_ip[i] = src_addr[i]; +// } +// key.src_port = tcph->source; + +// // Copy fingerprint to key +// #pragma unroll +// for (int i = 0; i < 14; i++) { +// key.fingerprint[i] = fingerprint[i]; +// } + +// // Check if fingerprint already exists in IPv6 map +// struct tcp_fingerprint_data *existing = +// bpf_map_lookup_elem(&tcp_fingerprints_v6, &key); +// if (existing) { +// // Update existing entry - must copy to local variable first +// data.first_seen = existing->first_seen; +// data.last_seen = timestamp; +// data.packet_count = existing->packet_count + 1; +// data.ttl = existing->ttl; +// data.mss = existing->mss; +// data.window_size = existing->window_size; +// data.window_scale = existing->window_scale; +// data.options_len = existing->options_len; + +// // Copy options array +// #pragma unroll +// for (int i = 0; i < TCP_FP_MAX_OPTION_LEN; i++) { +// data.options[i] = existing->options[i]; +// } + +// bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_ANY); +// } else { +// // Create new entry +// data.first_seen = timestamp; +// data.last_seen = timestamp; +// data.packet_count = 1; +// data.ttl = ttl; +// data.window_size = bpf_ntohs(tcph->window); + +// // Extract MSS and window scale from options +// parse_tcp_mss_wscale(tcph, data_end, &data.mss, +// &data.window_scale); + +// bpf_map_update_elem(&tcp_fingerprints_v6, &key, &data, BPF_ANY); +// increment_unique_fingerprints(); + +// // Log new IPv6 TCP fingerprint +// // bpf_printk("TCP_FP: New IPv6 fingerprint from %pI6:%d - TTL:%d +// // MSS:%d WS:%d Window:%d", +// // &ip6h->saddr, bpf_ntohs(tcph->source), ttl, +// data.mss, +// // data.window_scale, data.window_size); +// } +// } +// } +// } + +// // Check IPv6 port bans +// if (ip6h->nexthdr == IPPROTO_TCP || ip6h->nexthdr == IPPROTO_UDP) { +// void *port_cursor = cursor; +// __be16 src_port = 0; +// __be16 dst_port = 0; + +// if (ip6h->nexthdr == IPPROTO_TCP) { +// struct tcphdr *tcph_tmp = +// parse_and_advance(&port_cursor, data_end, sizeof(*tcph_tmp)); +// if (!tcph_tmp) +// return XDP_PASS; +// src_port = tcph_tmp->source; +// dst_port = tcph_tmp->dest; +// } else { +// struct udphdr *udph_tmp = +// parse_and_advance(&port_cursor, data_end, sizeof(*udph_tmp)); +// if (!udph_tmp) +// return XDP_PASS; +// src_port = udph_tmp->source; +// dst_port = udph_tmp->dest; +// } + +// struct src_port_key_v6 inbound_key6 = {0}; +// #pragma unroll +// for (int i = 0; i < 16; i++) { +// inbound_key6.addr[i] = ((__u8 *)&ip6h->saddr)[i]; +// } +// inbound_key6.port = src_port; + +// if (bpf_map_lookup_elem(&banned_inbound_ipv6_address_ports, +// &inbound_key6)) { +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->saddr); +// return XDP_DROP; +// } + +// struct src_port_key_v6 outbound_key6 = {0}; +// #pragma unroll +// for (int i = 0; i < 16; i++) { +// outbound_key6.addr[i] = ((__u8 *)&ip6h->daddr)[i]; +// } +// outbound_key6.port = dst_port; + +// if (bpf_map_lookup_elem(&banned_outbound_ipv6_address_ports, +// &outbound_key6)) { +// increment_total_packets_dropped(); +// increment_dropped_ipv6_address(ip6h->daddr); +// return XDP_DROP; +// } +// } + +// return XDP_PASS; +// } + +// return XDP_PASS; +// // return XDP_ABORTED; +// } + +char _license[] SEC("license") = "GPL"; diff --git a/src/security/firewall/bpf/xdp_maps.h b/src/security/firewall/bpf/xdp_maps.h new file mode 100644 index 0000000..d3a9689 --- /dev/null +++ b/src/security/firewall/bpf/xdp_maps.h @@ -0,0 +1,246 @@ +#pragma once + +#include "common.h" +#include "vmlinux.h" + +#define TCP_FINGERPRINT_MAX_ENTRIES 10000 + +// IPv4 maps: permanently banned and recently banned +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key); // IPv4 address in network byte order + __type(value, ip_flag_t); // presence flag (1) +} banned_ips SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key); + __type(value, ip_flag_t); +} recently_banned_ips SEC(".maps"); + +// IPv6 maps: permanently banned and recently banned +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key_v6); + __type(value, ip_flag_t); +} banned_ips_v6 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LPM_TRIE); + __uint(max_entries, CITADEL_IP_MAP_MAX); + __uint(map_flags, BPF_F_NO_PREALLOC); + __type(key, struct lpm_key_v6); + __type(value, ip_flag_t); +} recently_banned_ips_v6 SEC(".maps"); + +// Remove dynptr helpers, not used in XDP manual parsing +// extern int bpf_dynptr_from_skb(struct __sk_buff *skb, __u64 flags, +// struct bpf_dynptr *ptr__uninit) __ksym; +// extern void *bpf_dynptr_slice(const struct bpf_dynptr *ptr, uint32_t offset, +// void *buffer, uint32_t buffer__sz) __ksym; + +volatile int shootdowns = 0; + +// Statistics maps for tracking access rule hits +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv4_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv4_recently_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv6_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} ipv6_recently_banned_stats SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} total_packets_processed SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} total_packets_dropped SEC(".maps"); + +// TCP fingerprinting maps +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); + __type(key, struct tcp_fingerprint_key); + __type(value, struct tcp_fingerprint_data); +} tcp_fingerprints SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, TCP_FINGERPRINT_MAX_ENTRIES); + __type(key, struct tcp_fingerprint_key_v6); + __type(value, struct tcp_fingerprint_data); +} tcp_fingerprints_v6 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, struct tcp_syn_stats); +} tcp_syn_stats SEC(".maps"); + +// Blocked TCP fingerprint maps (only store the fingerprint string, not per-IP) +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 10000); // Store up to 10k blocked fingerprint patterns + __type(key, __u8[14]); // TCP fingerprint string (14 bytes) + __type(value, __u8); // Flag (1 = blocked) +} blocked_tcp_fingerprints SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 10000); + __type(key, __u8[14]); // TCP fingerprint string (14 bytes) + __type(value, __u8); // Flag (1 = blocked) +} blocked_tcp_fingerprints_v6 SEC(".maps"); + +// Statistics for TCP fingerprint blocks +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} tcp_fingerprint_blocks_ipv4 SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); +} tcp_fingerprint_blocks_ipv6 SEC(".maps"); + +// Maps to track dropped IP addresses with counters +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1000); // Track up to 1000 unique dropped IPs + __type(key, __be32); // IPv4 address + __type(value, __u64); // Drop count +} dropped_ipv4_addresses SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1000); // Track up to 1000 unique dropped IPv6s + __type(key, __u8[16]); // IPv6 address + __type(value, __u64); // Drop count +} dropped_ipv6_addresses SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 4096); + __type(key, struct src_port_key_v4); + __type(value, __u8); +} banned_inbound_ipv4_address_ports SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 4096); + __type(key, struct src_port_key_v6); + __type(value, __u8); +} banned_inbound_ipv6_address_ports SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 4096); + __type(key, struct src_port_key_v4); + __type(value, __u8); +} banned_outbound_ipv4_address_ports SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 4096); + __type(key, struct src_port_key_v6); + __type(value, __u8); +} banned_outbound_ipv6_address_ports SEC(".maps"); + +// RATE LIMITER MAPS + +struct ratelimiter_metrics { + __u64 requests_dropped; + __u64 total_ipv4_dropped; + __u64 total_ipv6_dropped; +}; + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 1); + __type(key, __u8); + __type(value, struct ratelimiter_metrics); +} ratelimit_metrics_map SEC(".maps"); + +#define DEFAULT_BUCKET_ENTRIES 50000 + +struct ratelimit_bucket_value { + __u64 last_topup; + __u64 num_of_tokens; +}; + +struct { + __uint(type, BPF_MAP_TYPE_LRU_PERCPU_HASH); + __uint(max_entries, + DEFAULT_BUCKET_ENTRIES); // to be modified in userspace before loading + // the script + __type(key, __be32); + __type(value, struct ratelimit_bucket_value); +} ipv4_syn_bucket_store SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_LRU_PERCPU_HASH); + __uint(max_entries, + DEFAULT_BUCKET_ENTRIES); // to be modified in userspace before loading + // the script + __type(key, ipv6_addr); + __type(value, struct ratelimit_bucket_value); +} ipv6_syn_bucket_store SEC(".maps"); + +#define DEFAULT_RATELIMITER_WHITELIST_ENTRIES_IPV4 100 +#define DEFAULT_RATELIMITER_WHITELIST_ENTRIES_IPV6 100 + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, DEFAULT_RATELIMITER_WHITELIST_ENTRIES_IPV4); + __type(key, __be32); + __type(value, __u8); +} ipv4_ratelimit_whitelist SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, DEFAULT_RATELIMITER_WHITELIST_ENTRIES_IPV6); + __type(key, ipv6_addr); + __type(value, __u8); +} ipv6_ratelimit_whitelist SEC(".maps"); + +// TODO: volumetric rate limiter, ratelimiter diagnostics, integration in the +// app. \ No newline at end of file diff --git a/src/security/firewall/mod.rs b/src/security/firewall/mod.rs index 91e4a5d..015d414 100644 --- a/src/security/firewall/mod.rs +++ b/src/security/firewall/mod.rs @@ -12,7 +12,7 @@ pub mod iptables; pub mod nftables; pub mod bpf { // Include the skeleton generated by build.rs into OUT_DIR at compile time - include!(concat!(env!("OUT_DIR"), "/filter.skel.rs")); + include!(concat!(env!("OUT_DIR"), "/xdp.skel.rs")); } pub use iptables::IptablesFirewall; pub use nftables::NftablesFirewall; @@ -92,11 +92,11 @@ pub trait Firewall { } pub struct SYNAPSEFirewall<'a> { - skel: &'a crate::security::firewall::bpf::FilterSkel<'a>, + skel: &'a crate::security::firewall::bpf::XdpSkel<'a>, } impl<'a> SYNAPSEFirewall<'a> { - pub fn new(skel: &'a crate::security::firewall::bpf::FilterSkel<'a>) -> Self { + pub fn new(skel: &'a crate::security::firewall::bpf::XdpSkel<'a>) -> Self { Self { skel } } diff --git a/src/security/mod.rs b/src/security/mod.rs index 0596232..fcc8d95 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -1,4 +1,6 @@ pub mod access_rules; #[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] pub mod firewall; +#[cfg(all(feature = "bpf", not(feature = "disable-bpf")))] +pub mod ratelimiter; pub mod waf; diff --git a/src/security/ratelimiter/mod.rs b/src/security/ratelimiter/mod.rs new file mode 100644 index 0000000..f59410a --- /dev/null +++ b/src/security/ratelimiter/mod.rs @@ -0,0 +1,102 @@ +use crate::{ + security::firewall::bpf::{OpenXdpSkel, XdpSkel}, + worker::config::XDPRateLimitConfig, +}; +use anyhow::Result; +use libbpf_rs::{MapCore, MapFlags}; +use serde::{Deserialize, Serialize}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +pub struct XDPRateLimit<'a, 'b> { + skel: &'a mut XdpSkel<'b>, +} + +impl<'a, 'b> XDPRateLimit<'a, 'b> { + pub fn new(skel: &'a mut XdpSkel<'b>) -> Self { + Self { skel } + } + + pub fn set_request_per_sec(&mut self, request_per_seq: u64, burst_multiplier: f32) { + if let Some(data) = self.skel.maps.data_data.as_mut() { + data.ratelimiter_config.TOKENS_PER_REQUEST = 1; + data.ratelimiter_config.REFILL_RATE = request_per_seq; + data.ratelimiter_config.MAX_BUCKET_CAPACITY = + (request_per_seq as f32 * burst_multiplier) as u64; + } + } + + pub fn set_ratelimiter_status(&mut self, enabled: bool) { + if let Some(data) = self.skel.maps.data_data.as_mut() { + data.ratelimiter_config.ENABLED = enabled as u8; + } + } + + pub fn setup_from_config(&mut self, config: &XDPRateLimitConfig) { + self.set_request_per_sec(config.requests, config.burst_factor); + self.set_ratelimiter_status(config.enabled); + } + + pub fn add_ipv4_to_whitelist(&mut self, ip: Ipv4Addr) -> Result<()> { + let key = u32::from(ip).to_be_bytes(); + let value: u8 = 1; + self.skel + .maps + .ipv4_ratelimit_whitelist + .update(&key, &[value], MapFlags::ANY)?; + Ok(()) + } + + pub fn add_ipv6_to_whitelist(&mut self, ip: Ipv6Addr) -> Result<()> { + let key = ip.octets(); + let value: u8 = 1; + self.skel + .maps + .ipv6_ratelimit_whitelist + .update(&key, &[value], MapFlags::ANY)?; + Ok(()) + } + + pub fn remove_ipv4_from_whitelist(&mut self, ip: Ipv4Addr) -> Result<()> { + let key = u32::from(ip).to_be_bytes(); + self.skel.maps.ipv4_ratelimit_whitelist.delete(&key)?; + Ok(()) + } + + pub fn remove_ipv6_from_whitelist(&mut self, ip: Ipv6Addr) -> Result<()> { + let key = ip.octets(); + self.skel.maps.ipv6_ratelimit_whitelist.delete(&key)?; + Ok(()) + } + + pub fn ipv4_bucket_max_entrie(&self) -> u32 { + self.skel.maps.ipv4_syn_bucket_store.max_entries() + } + + pub fn ipv6_bucket_max_entrie(&self) -> u32 { + self.skel.maps.ipv6_syn_bucket_store.max_entries() + } +} + +// setup before loading the ebpf program + +pub fn set_bucket_map_size_ipv4<'b>( + skel: &mut OpenXdpSkel<'b>, + bucket_map_size: u32, +) -> Result<()> { + skel.maps + .ipv4_syn_bucket_store + .set_max_entries(bucket_map_size)?; + + Ok(()) +} + +pub fn set_bucket_map_size_ipv6<'b>( + skel: &mut OpenXdpSkel<'b>, + bucket_map_size: u32, +) -> Result<()> { + skel.maps + .ipv6_syn_bucket_store + .set_max_entries(bucket_map_size)?; + + Ok(()) +} diff --git a/src/security/waf/content_scanning/mod.rs b/src/security/waf/content_scanning/mod.rs new file mode 100644 index 0000000..b607526 --- /dev/null +++ b/src/security/waf/content_scanning/mod.rs @@ -0,0 +1,652 @@ +use anyhow::{Result, anyhow}; +use clamav_tcp::scan; +use serde::{Deserialize, Serialize}; +use std::io::Cursor; +use std::sync::{Arc, RwLock, OnceLock}; +use std::collections::HashMap; +use std::net::SocketAddr; +use hyper::http::request::Parts; +use wirefilter::{ExecutionContext, Scheme, Filter}; +use bytes::Bytes; +use multer::Multipart; + +/// Content scanning configuration +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ContentScanningConfig { + /// Enable or disable content scanning + pub enabled: bool, + /// ClamAV server address (e.g., "localhost:3310") + pub clamav_server: String, + /// Maximum file size to scan in bytes (default: 10MB) + pub max_file_size: usize, + /// Content types to scan (empty means scan all) + pub scan_content_types: Vec, + /// Skip scanning for specific file extensions + pub skip_extensions: Vec, + /// Wirefilter expression to determine when to scan + #[serde(default = "default_scan_expression")] + pub scan_expression: String, +} + +fn default_scan_expression() -> String { + "http.request.method eq \"POST\" or http.request.method eq \"PUT\"".to_string() +} + +impl Default for ContentScanningConfig { + fn default() -> Self { + Self { + enabled: false, + clamav_server: "localhost:3310".to_string(), + max_file_size: 10 * 1024 * 1024, // 10MB + scan_content_types: vec![ + "text/html".to_string(), + "application/x-www-form-urlencoded".to_string(), + "multipart/form-data".to_string(), + "application/json".to_string(), + "text/plain".to_string(), + ], + skip_extensions: vec![], + scan_expression: default_scan_expression(), + } + } +} + +/// Content scanning result +#[derive(Debug, Clone)] +pub struct ScanResult { + /// Whether malware was detected + pub malware_detected: bool, + /// Malware signature name if detected + pub signature: Option, + /// Error message if scanning failed + pub error: Option, +} + +/// Content scanner implementation +pub struct ContentScanner { + config: Arc>, + scheme: Arc, + filter: Arc>>, +} + +/// Extract boundary from Content-Type header for multipart content +pub fn extract_multipart_boundary(content_type: &str) -> Option { + // Content-Type format: multipart/form-data; boundary=----WebKitFormBoundary... + if !content_type.to_lowercase().contains("multipart/") { + return None; + } + + for part in content_type.split(';') { + let trimmed = part.trim(); + let lower = trimmed.to_lowercase(); + if lower.starts_with("boundary=") { + // Find the actual position of "boundary=" in the original string (case-insensitive) + if let Some(eq_pos) = trimmed.to_lowercase().find("boundary=") { + let boundary = trimmed[eq_pos + 9..].trim(); + // Remove quotes if present + let boundary = boundary.trim_matches('"').trim_matches('\''); + return Some(boundary.to_string()); + } + } + } + + None +} + +impl ContentScanner { + /// Create a new content scanner + pub fn new(config: ContentScanningConfig) -> Self { + let scheme = Self::create_scheme(); + let filter = Self::compile_filter(&scheme, &config.scan_expression); + + Self { + config: Arc::new(RwLock::new(config)), + scheme: Arc::new(scheme), + filter: Arc::new(RwLock::new(filter)), + } + } + + /// Create the wirefilter scheme for content scanning + fn create_scheme() -> Scheme { + let builder = wirefilter::Scheme! { + http.request.method: Bytes, + http.request.path: Bytes, + http.request.content_type: Bytes, + http.request.content_length: Int, + }; + builder.build() + } + + /// Compile the scan expression filter + fn compile_filter(scheme: &Scheme, expression: &str) -> Option { + if expression.is_empty() { + return None; + } + + match scheme.parse(expression) { + Ok(ast) => Some(ast.compile()), + Err(e) => { + log::error!("Failed to compile content scanning expression '{}': {}", expression, e); + None + } + } + } + + /// Update scanner configuration + pub fn update_config(&self, config: ContentScanningConfig) { + let new_filter = Self::compile_filter(&self.scheme, &config.scan_expression); + + if let Ok(mut guard) = self.config.write() { + *guard = config; + } + if let Ok(mut guard) = self.filter.write() { + *guard = new_filter; + } + } + + /// Check if content scanning should be performed for this request + pub fn should_scan(&self, req_parts: &Parts, body_bytes: &[u8], _peer_addr: SocketAddr) -> bool { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return false, + }; + + if !config.enabled { + log::debug!("Content scanning disabled"); + return false; + } + + log::debug!("Checking if should scan request: method={}, path={}, body_size={}", + req_parts.method, req_parts.uri.path(), body_bytes.len()); + + // Check wirefilter expression first + let filter_guard = match self.filter.read() { + Ok(guard) => guard, + Err(_) => return false, + }; + + if let Some(ref filter) = *filter_guard { + let mut ctx = ExecutionContext::new(&self.scheme); + + // Set request fields + let method = req_parts.method.as_str(); + let path = req_parts.uri.path(); + let content_type = req_parts.headers + .get("content-type") + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + let content_length = body_bytes.len() as i64; + + if ctx.set_field_value(self.scheme.get_field("http.request.method").unwrap(), method).is_err() { + return false; + } + if ctx.set_field_value(self.scheme.get_field("http.request.path").unwrap(), path).is_err() { + return false; + } + if ctx.set_field_value(self.scheme.get_field("http.request.content_type").unwrap(), content_type).is_err() { + return false; + } + if ctx.set_field_value(self.scheme.get_field("http.request.content_length").unwrap(), content_length).is_err() { + return false; + } + + // Execute filter + match filter.execute(&ctx) { + Ok(result) => { + if !result { + log::debug!("Skipping content scan: expression does not match"); + return false; + } else { + log::debug!("Expression matched, proceeding with content scan checks"); + log::debug!("Expression result: {:?}", result); + } + } + Err(e) => { + log::error!("Failed to execute content scanning expression: {}", e); + return false; + } + } + } else { + log::debug!("No scan expression configured, allowing scan"); + } + + // Check if body is too large + if body_bytes.len() > config.max_file_size { + log::debug!("Skipping content scan: body too large ({} bytes)", body_bytes.len()); + return false; + } + + // Check content type + if let Some(content_type) = req_parts.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + let content_type_lower = content_type_str.to_lowercase(); + + // If specific content types are configured, only scan those + if !config.scan_content_types.is_empty() { + let should_scan = config.scan_content_types.iter() + .any(|ct| content_type_lower.contains(ct)); + if !should_scan { + log::debug!("Skipping content scan: content type '{}' not in scan list: {:?}", + content_type_str, config.scan_content_types); + return false; + } else { + log::debug!("Content type '{}' matches scan list", content_type_str); + } + } + + // Skip certain content types + if content_type_lower.contains("image/") || + content_type_lower.contains("video/") || + content_type_lower.contains("audio/") { + log::debug!("Skipping content scan: binary content type {}", content_type_str); + return false; + } + } + } + + // Check file extension from URL path + if let Some(path) = req_parts.uri.path().split('/').last() { + if let Some(extension) = std::path::Path::new(path).extension() { + if let Some(ext_str) = extension.to_str() { + let ext_lower = format!(".{}", ext_str.to_lowercase()); + if config.skip_extensions.contains(&ext_lower) { + log::debug!("Skipping content scan: file extension {} in skip list", ext_lower); + return false; + } + } + } + } + + log::debug!("All checks passed, will scan content"); + true + } + + /// Scan content for malware + pub async fn scan_content(&self, body_bytes: &[u8]) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + self.scan_bytes(&config.clamav_server, body_bytes).await + } + + /// Internal method to scan bytes with ClamAV + async fn scan_bytes(&self, clamav_server: &str, data: &[u8]) -> Result { + // Create a cursor over the body bytes for scanning + let mut cursor = Cursor::new(data); + + // Perform the scan + match scan(clamav_server, &mut cursor, None) { + Ok(result) => { + // Check if malware was detected using the new API + if !result.is_infected { + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } else { + // Extract signature name from detected_infections + let signature = if !result.detected_infections.is_empty() { + Some(result.detected_infections.join(", ")) + } else { + None + }; + + Ok(ScanResult { + malware_detected: true, + signature, + error: None, + }) + } + } + Err(e) => { + log::error!("ClamAV scan failed: {}", e); + Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("Scan failed: {}", e)), + }) + } + } + } + + /// Scan multipart content for malware by parsing parts and scanning each individually + pub async fn scan_multipart_content(&self, body_bytes: &[u8], boundary: &str) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + log::debug!("Parsing multipart body with boundary: {}", boundary); + + // Create a multipart parser + let stream = futures::stream::once(async move { + Result::::Ok(Bytes::copy_from_slice(body_bytes)) + }); + + let mut multipart = Multipart::new(stream, boundary); + + let mut parts_scanned = 0; + let mut parts_failed = 0; + + // Iterate over each part in the multipart body + while let Some(field) = multipart.next_field().await.map_err(|e| anyhow!("Failed to read multipart field: {}", e))? { + let field_name = field.name().unwrap_or("").to_string(); + let field_filename = field.file_name().map(|s| s.to_string()); + let field_content_type = field.content_type().map(|m| m.to_string()); + + log::debug!("Scanning multipart field: name={}, filename={:?}, content_type={:?}", + field_name, field_filename, field_content_type); + + // Read the entire field into bytes + let field_bytes = field.bytes().await.map_err(|e| anyhow!("Failed to read field bytes: {}", e))?; + + // Skip empty fields + if field_bytes.is_empty() { + log::debug!("Skipping empty multipart field: {}", field_name); + continue; + } + + // Check if field size exceeds max_file_size + if field_bytes.len() > config.max_file_size { + log::debug!("Skipping multipart field '{}': size {} exceeds max_file_size {}", + field_name, field_bytes.len(), config.max_file_size); + continue; + } + + parts_scanned += 1; + + // Scan this part + match self.scan_bytes(&config.clamav_server, &field_bytes).await { + Ok(result) => { + if result.malware_detected { + log::info!("Malware detected in multipart field '{}' (filename: {:?}): signature {:?}", + field_name, field_filename, result.signature); + + // Return immediately on first malware detection + return Ok(ScanResult { + malware_detected: true, + signature: result.signature.map(|s| format!("{}:{}", field_name, s)), + error: None, + }); + } + } + Err(e) => { + log::warn!("Failed to scan multipart field '{}': {}", field_name, e); + parts_failed += 1; + } + } + } + + log::debug!("Multipart scan complete: {} parts scanned, {} failed", parts_scanned, parts_failed); + + // If all parts failed to scan, return an error + if parts_scanned > 0 && parts_failed == parts_scanned { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("All {} multipart parts failed to scan", parts_failed)), + }); + } + + // No malware detected + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } + + /// Scan HTML form data for malware + pub async fn scan_form_data(&self, form_data: &HashMap) -> Result { + let config = match self.config.read() { + Ok(guard) => guard.clone(), + Err(_) => return Err(anyhow!("Failed to read scanner config")), + }; + + if !config.enabled { + return Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }); + } + + // Combine all form values into a single string for scanning + let combined_data = form_data.values() + .map(|v| v.as_str()) + .collect::>() + .join("\n"); + + let mut cursor = Cursor::new(combined_data.as_bytes()); + + // Perform the scan + match scan(&config.clamav_server, &mut cursor, None) { + Ok(result) => { + // Check if malware was detected using the new API + if !result.is_infected { + Ok(ScanResult { + malware_detected: false, + signature: None, + error: None, + }) + } else { + // Extract signature name from detected_infections + let signature = if !result.detected_infections.is_empty() { + Some(result.detected_infections.join(", ")) + } else { + None + }; + + Ok(ScanResult { + malware_detected: true, + signature, + error: None, + }) + } + } + Err(e) => { + log::error!("ClamAV form data scan failed: {}", e); + Ok(ScanResult { + malware_detected: false, + signature: None, + error: Some(format!("Form data scan failed: {}", e)), + }) + } + } + } +} + +// Global content scanner instance +static CONTENT_SCANNER: OnceLock = OnceLock::new(); + +/// Get the global content scanner instance +pub fn get_global_content_scanner() -> Option<&'static ContentScanner> { + CONTENT_SCANNER.get() +} + +/// Set the global content scanner instance +pub fn set_global_content_scanner(scanner: ContentScanner) -> Result<()> { + CONTENT_SCANNER + .set(scanner) + .map_err(|_| anyhow!("Failed to initialize content scanner")) +} + +/// Initialize the global content scanner with default configuration +pub fn init_content_scanner(config: ContentScanningConfig) -> Result<()> { + let scanner = ContentScanner::new(config); + set_global_content_scanner(scanner)?; + log::info!("Content scanner initialized"); + Ok(()) +} + +/// Update the global content scanner configuration +pub fn update_content_scanner_config(config: ContentScanningConfig) -> Result<()> { + if let Some(scanner) = get_global_content_scanner() { + scanner.update_config(config); + log::info!("Content scanner configuration updated"); + Ok(()) + } else { + Err(anyhow!("Content scanner not initialized")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::http::request::Builder; + + #[test] + fn test_content_scanner_config_default() { + let config = ContentScanningConfig::default(); + assert!(!config.enabled); + assert_eq!(config.clamav_server, "localhost:3310"); + assert_eq!(config.max_file_size, 10 * 1024 * 1024); + assert!(!config.scan_content_types.is_empty()); + assert!(config.skip_extensions.is_empty()); + } + + #[test] + fn test_should_scan_disabled() { + use std::net::{Ipv4Addr, SocketAddr}; + + let config = ContentScanningConfig { + enabled: false, + ..Default::default() + }; + let scanner = ContentScanner::new(config); + + let req = Builder::new() + .method("POST") + .uri("http://example.com/test") + .body(()) + .unwrap(); + let (req_parts, _) = req.into_parts(); + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + assert!(!scanner.should_scan(&req_parts, b"test content", peer_addr)); + } + + #[test] + fn test_should_scan_content_type_filter() { + use std::net::{Ipv4Addr, SocketAddr}; + + let config = ContentScanningConfig { + enabled: true, + scan_content_types: vec!["text/html".to_string()], + ..Default::default() + }; + let scanner = ContentScanner::new(config); + + let req = Builder::new() + .method("POST") + .uri("http://example.com/test") + .header("content-type", "text/html") + .body(()) + .unwrap(); + let (req_parts, _) = req.into_parts(); + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + assert!(scanner.should_scan(&req_parts, b"test", peer_addr)); + + let req2 = Builder::new() + .method("POST") + .uri("http://example.com/test") + .header("content-type", "application/json") + .body(()) + .unwrap(); + let (req_parts2, _) = req2.into_parts(); + + assert!(!scanner.should_scan(&req_parts2, b"{\"test\": \"data\"}", peer_addr)); + } + + #[test] + fn test_should_scan_file_size_limit() { + use std::net::{Ipv4Addr, SocketAddr}; + + let config = ContentScanningConfig { + enabled: true, + max_file_size: 100, + ..Default::default() + }; + let scanner = ContentScanner::new(config); + + let req = Builder::new() + .method("POST") + .uri("http://example.com/test") + .body(()) + .unwrap(); + let (req_parts, _) = req.into_parts(); + let peer_addr = SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 8080); + + let small_content = b"small"; + let large_content = b"x".repeat(200); + + assert!(scanner.should_scan(&req_parts, small_content, peer_addr)); + assert!(!scanner.should_scan(&req_parts, &large_content, peer_addr)); + } + + #[test] + fn test_extract_multipart_boundary() { + // Test with standard format + let ct1 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_multipart_boundary(ct1), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + // Test with quoted boundary + let ct2 = "multipart/form-data; boundary=\"----WebKitFormBoundary7MA4YWxkTrZu0gW\""; + assert_eq!( + extract_multipart_boundary(ct2), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + // Test with spaces + let ct3 = "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW "; + assert_eq!( + extract_multipart_boundary(ct3), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + // Test with charset and boundary + let ct4 = "multipart/form-data; charset=utf-8; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_multipart_boundary(ct4), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + + // Test non-multipart content type + let ct5 = "application/json"; + assert_eq!(extract_multipart_boundary(ct5), None); + + // Test missing boundary + let ct6 = "multipart/form-data"; + assert_eq!(extract_multipart_boundary(ct6), None); + + // Test mixed case + let ct7 = "Multipart/Form-Data; Boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"; + assert_eq!( + extract_multipart_boundary(ct7), + Some("----WebKitFormBoundary7MA4YWxkTrZu0gW".to_string()) + ); + } +} diff --git a/src/security/waf/wirefilter.rs b/src/security/waf/wirefilter.rs index 0d7c406..7746131 100644 --- a/src/security/waf/wirefilter.rs +++ b/src/security/waf/wirefilter.rs @@ -2,8 +2,9 @@ use std::collections::HashSet; use std::net::SocketAddr; use std::sync::{Arc, OnceLock, RwLock}; +use crate::security::ratelimiter::XDPRateLimit; use crate::security::waf::threat; -use crate::worker::config::{Config, fetch_config}; +use crate::worker::config::{AccessRuleConfig, Config, XDPRateLimitConfig, fetch_config}; use anyhow::Result; use anyhow::anyhow; use sha2::{Digest, Sha256}; @@ -871,6 +872,7 @@ pub async fn load_waf_rules(waf_rules: Vec) -> a country: vec![], ips: vec![], }, + config: AccessRuleConfig::default(), }, waf_rules: crate::worker::config::WafRules { rules: waf_rules }, content_scanning: diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 2a0110b..4c037f3 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -1,399 +1,399 @@ -use anyhow::{Context, Result}; -use redis::Client; -use redis::aio::ConnectionManager; -use std::sync::Arc; -use tokio::sync::OnceCell; -use tokio::time::{Duration, timeout}; - -/// Global Redis connection manager -static REDIS_MANAGER: OnceCell> = OnceCell::const_new(); - -/// Global TLS connector for Redis SSL connections -static REDIS_TLS_CONNECTOR: OnceCell> = OnceCell::const_new(); - -/// Centralized Redis connection manager -pub struct RedisManager { - pub connection: ConnectionManager, - pub prefix: String, -} - -impl RedisManager { - /// Initialize the global Redis manager - pub async fn init( - redis_url: &str, - prefix: String, - ssl_config: Option<&crate::core::cli::RedisSslConfig>, - ) -> Result<()> { - log::info!("Initializing Redis manager with URL: {}", redis_url); - - // Add a short connect timeout so startup doesn't block for minutes if Redis is unreachable - let mut url_with_timeout = redis_url.to_string(); - if !url_with_timeout.contains("connect_timeout=") { - if url_with_timeout.contains('?') { - url_with_timeout.push_str("&connect_timeout=10"); - } else { - url_with_timeout.push_str("?connect_timeout=10"); - } - log::info!( - "Redis URL updated with connect_timeout=10s: {}", - url_with_timeout - ); - } - - // If SSL config is provided, ensure URL uses rediss:// protocol - let redis_url = if let Some(_ssl_config) = ssl_config { - if url_with_timeout.starts_with("redis://") - && !url_with_timeout.starts_with("rediss://") - { - let converted_url = url_with_timeout.replacen("redis://", "rediss://", 1); - log::info!( - "SSL config provided, converting URL from redis:// to rediss://: {}", - converted_url - ); - converted_url - } else { - url_with_timeout.to_string() - } - } else { - url_with_timeout.to_string() - }; - - let client = if let Some(ssl_config) = ssl_config { - // Configure Redis client with custom SSL certificates - Self::create_client_with_ssl(&redis_url, ssl_config)? - } else { - // Use default client (will handle rediss:// URLs automatically) - Client::open(redis_url).context("Failed to create Redis client")? - }; - - let connection = timeout(Duration::from_secs(15), client.get_connection_manager()) - .await - .map_err(|_| anyhow::anyhow!("Redis connection manager creation timed out"))? - .context("Failed to create Redis connection manager")?; - - log::info!( - "Redis connection manager created successfully with prefix: {}", - prefix - ); - - // Test the connection - let mut test_conn = connection.clone(); - let ping_result = timeout( - Duration::from_secs(3), - redis::cmd("PING").query_async::(&mut test_conn), - ) - .await; - match ping_result { - Ok(Ok(_)) => log::info!("Redis connection test successful"), - Ok(Err(e)) => { - log::warn!("Redis connection test failed: {}", e); - return Err(anyhow::anyhow!("Redis connection test failed: {}", e)); - } - Err(_) => { - log::warn!("Redis connection test timed out"); - return Err(anyhow::anyhow!("Redis connection test timed out")); - } - } - - let manager = Arc::new(RedisManager { connection, prefix }); - - REDIS_MANAGER - .set(manager) - .map_err(|_| anyhow::anyhow!("Redis manager already initialized"))?; - - Ok(()) - } - - /// Get the global Redis manager instance - pub fn get() -> Result> { - REDIS_MANAGER - .get() - .cloned() - .context("Redis manager not initialized") - } - - /// Get a connection manager for use in other modules - pub fn get_connection(&self) -> ConnectionManager { - self.connection.clone() - } - - /// Get the configured prefix - pub fn get_prefix(&self) -> &str { - &self.prefix - } - - /// Create a namespaced prefix - pub fn create_namespace(&self, namespace: &str) -> String { - format!("{}:{}", self.prefix, namespace) - } - - /// Get the global TLS connector if it was configured - /// This can be used for custom connection handling if needed - pub fn get_tls_connector() -> Option> { - REDIS_TLS_CONNECTOR.get().cloned() - } - - /// Create Redis client with custom SSL/TLS configuration - fn create_client_with_ssl( - redis_url: &str, - ssl_config: &crate::core::cli::RedisSslConfig, - ) -> Result { - use native_tls::{Certificate, Identity, TlsConnector}; - - // Build TLS connector with custom certificates - let mut tls_builder = TlsConnector::builder(); - - // Load CA certificate if provided - if let Some(ca_cert_path) = &ssl_config.ca_cert_path { - let ca_cert_data = std::fs::read(ca_cert_path) - .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; - let ca_cert = Certificate::from_pem(&ca_cert_data) - .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; - tls_builder.add_root_certificate(ca_cert); - log::info!("Redis SSL: Loaded CA certificate from {}", ca_cert_path); - - // Set SSL_CERT_FILE environment variable as a workaround for native-tls/OpenSSL - // This allows the underlying TLS library to use the custom CA certificate - // Note: This affects the current process and child processes - unsafe { - std::env::set_var("SSL_CERT_FILE", ca_cert_path); - } - log::debug!( - "Redis SSL: Set SSL_CERT_FILE environment variable to {}", - ca_cert_path - ); - } - - // Load client certificate and key if provided - if let (Some(client_cert_path), Some(client_key_path)) = - (&ssl_config.client_cert_path, &ssl_config.client_key_path) - { - let client_cert_data = std::fs::read(client_cert_path).with_context(|| { - format!( - "Failed to read client certificate from {}", - client_cert_path - ) - })?; - let client_key_data = std::fs::read(client_key_path) - .with_context(|| format!("Failed to read client key from {}", client_key_path))?; - - // Try to create identity from PEM format (cert + key) - let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) - .or_else(|_| { - // Try PEM format if PKCS#8 fails - Identity::from_pkcs12(&client_cert_data, "") - }) - .or_else(|_| { - // Try loading as separate PEM files - // Combine cert and key into a single PEM - let mut combined = client_cert_data.clone(); - combined.extend_from_slice(b"\n"); - combined.extend_from_slice(&client_key_data); - Identity::from_pkcs12(&combined, "") - }) - .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; - tls_builder.identity(identity); - log::info!( - "Redis SSL: Loaded client certificate from {} and key from {}", - client_cert_path, - client_key_path - ); - - // Set SSL client certificate environment variables as workaround - // Note: native-tls/OpenSSL may use these for client certificate authentication - unsafe { - std::env::set_var("SSL_CLIENT_CERT", client_cert_path); - std::env::set_var("SSL_CLIENT_KEY", client_key_path); - } - log::debug!("Redis SSL: Set SSL_CLIENT_CERT and SSL_CLIENT_KEY environment variables"); - } - - // Configure certificate verification - if ssl_config.insecure { - tls_builder.danger_accept_invalid_certs(true); - tls_builder.danger_accept_invalid_hostnames(true); - log::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); - } - - // Build the TLS connector with our custom certificate configuration - // This connector will be used by native-tls/OpenSSL for TLS connections - let tls_connector = tls_builder - .build() - .with_context(|| "Failed to build TLS connector")?; - - // Store the TLS connector globally so it can be used by native-tls - // The redis crate with tokio-native-tls-comp uses native-tls internally, - // which will use OpenSSL. OpenSSL respects the SSL_CERT_FILE environment - // variable we set above, and will use the system's default TLS context - // which we've configured through the TlsConnector builder. - let tls_connector_arc = Arc::new(tls_connector); - // Store globally - allow re-initialization in tests by ignoring the error if already set - if REDIS_TLS_CONNECTOR.set(tls_connector_arc.clone()).is_err() { - log::debug!("Redis SSL: TLS connector already initialized, using existing one"); - } else { - log::info!("Redis SSL: TLS connector configured and stored globally"); - } - - // Note: The redis crate (v0.32) with tokio-native-tls-comp uses native-tls internally, - // which in turn uses OpenSSL. While we cannot pass our TlsConnector directly to the - // redis crate, we've configured it properly and set environment variables that - // OpenSSL respects: - // - // 1. SSL_CERT_FILE: Points to our custom CA certificate (if provided) - // 2. SSL_CLIENT_CERT/SSL_CLIENT_KEY: Points to client certificates (if provided) - // 3. The TlsConnector is built and stored, ensuring certificates are valid - // - // OpenSSL will use these environment variables when creating TLS connections, - // which means our custom certificate configuration will be applied. - - let client = Client::open(redis_url) - .with_context(|| "Failed to create Redis client with SSL config")?; - - Ok(client) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::cli::RedisSslConfig; - - #[tokio::test] - async fn test_redis_manager_init() { - // This test would require a Redis instance running - // For now, just test that the structure compiles - assert!(true); - } - - #[test] - fn test_create_client_with_ssl_no_config() { - // Test that client creation works without SSL config - let redis_url = "redis://127.0.0.1:6379"; - let result = Client::open(redis_url); - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_insecure() { - // Test SSL config with insecure mode - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: true, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed even without certificate files when insecure is true - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_missing_ca_cert() { - // Test that missing CA cert file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: Some("/nonexistent/path/ca.crt".to_string()), - client_cert_path: None, - client_key_path: None, - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because CA cert file doesn't exist - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("Failed to read CA certificate") - ); - } - - #[test] - fn test_create_client_with_ssl_missing_client_cert() { - // Test that missing client cert file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: Some("/nonexistent/path/client.key".to_string()), - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because client cert file doesn't exist - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("Failed to read client certificate") - ); - } - - #[test] - fn test_create_client_with_ssl_missing_client_key() { - // Test that missing client key file returns error - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: Some("/nonexistent/path/client.key".to_string()), - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should fail because client key file doesn't exist - assert!(result.is_err()); - } - - #[test] - fn test_create_client_with_ssl_partial_client_config() { - // Test that providing only cert or only key (not both) still validates - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: Some("/nonexistent/path/client.crt".to_string()), - client_key_path: None, // Missing key - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed because we only validate when both cert and key are provided - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_empty_config() { - // Test SSL config with all None values - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: false, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed with empty config (TLS connector builds without custom certs) - assert!(result.is_ok()); - } - - #[test] - fn test_create_client_with_ssl_insecure_builds_connector() { - // Test that insecure mode builds TLS connector successfully - let ssl_config = RedisSslConfig { - ca_cert_path: None, - client_cert_path: None, - client_key_path: None, - insecure: true, - }; - - let redis_url = "rediss://127.0.0.1:6379"; - let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); - // Should succeed - TLS connector builds with insecure settings - assert!(result.is_ok()); - } -} +use anyhow::{Context, Result}; +use redis::Client; +use redis::aio::ConnectionManager; +use std::sync::Arc; +use tokio::sync::OnceCell; +use tokio::time::{Duration, timeout}; + +/// Global Redis connection manager +static REDIS_MANAGER: OnceCell> = OnceCell::const_new(); + +/// Global TLS connector for Redis SSL connections +static REDIS_TLS_CONNECTOR: OnceCell> = OnceCell::const_new(); + +/// Centralized Redis connection manager +pub struct RedisManager { + pub connection: ConnectionManager, + pub prefix: String, +} + +impl RedisManager { + /// Initialize the global Redis manager + pub async fn init( + redis_url: &str, + prefix: String, + ssl_config: Option<&crate::core::cli::RedisSslConfig>, + ) -> Result<()> { + log::info!("Initializing Redis manager with URL: {}", redis_url); + + // Add a short connect timeout so startup doesn't block for minutes if Redis is unreachable + let mut url_with_timeout = redis_url.to_string(); + if !url_with_timeout.contains("connect_timeout=") { + if url_with_timeout.contains('?') { + url_with_timeout.push_str("&connect_timeout=10"); + } else { + url_with_timeout.push_str("?connect_timeout=10"); + } + log::info!( + "Redis URL updated with connect_timeout=10s: {}", + url_with_timeout + ); + } + + // If SSL config is provided, ensure URL uses rediss:// protocol + let redis_url = if let Some(_ssl_config) = ssl_config { + if url_with_timeout.starts_with("redis://") + && !url_with_timeout.starts_with("rediss://") + { + let converted_url = url_with_timeout.replacen("redis://", "rediss://", 1); + log::info!( + "SSL config provided, converting URL from redis:// to rediss://: {}", + converted_url + ); + converted_url + } else { + url_with_timeout.to_string() + } + } else { + url_with_timeout.to_string() + }; + + let client = if let Some(ssl_config) = ssl_config { + // Configure Redis client with custom SSL certificates + Self::create_client_with_ssl(&redis_url, ssl_config)? + } else { + // Use default client (will handle rediss:// URLs automatically) + Client::open(redis_url).context("Failed to create Redis client")? + }; + + let connection = timeout(Duration::from_secs(15), client.get_connection_manager()) + .await + .map_err(|_| anyhow::anyhow!("Redis connection manager creation timed out"))? + .context("Failed to create Redis connection manager")?; + + log::info!( + "Redis connection manager created successfully with prefix: {}", + prefix + ); + + // Test the connection + let mut test_conn = connection.clone(); + let ping_result = timeout( + Duration::from_secs(3), + redis::cmd("PING").query_async::(&mut test_conn), + ) + .await; + match ping_result { + Ok(Ok(_)) => log::info!("Redis connection test successful"), + Ok(Err(e)) => { + log::warn!("Redis connection test failed: {}", e); + return Err(anyhow::anyhow!("Redis connection test failed: {}", e)); + } + Err(_) => { + log::warn!("Redis connection test timed out"); + return Err(anyhow::anyhow!("Redis connection test timed out")); + } + } + + let manager = Arc::new(RedisManager { connection, prefix }); + + REDIS_MANAGER + .set(manager) + .map_err(|_| anyhow::anyhow!("Redis manager already initialized"))?; + + Ok(()) + } + + /// Get the global Redis manager instance + pub fn get() -> Result> { + REDIS_MANAGER + .get() + .cloned() + .context("Redis manager not initialized") + } + + /// Get a connection manager for use in other modules + pub fn get_connection(&self) -> ConnectionManager { + self.connection.clone() + } + + /// Get the configured prefix + pub fn get_prefix(&self) -> &str { + &self.prefix + } + + /// Create a namespaced prefix + pub fn create_namespace(&self, namespace: &str) -> String { + format!("{}:{}", self.prefix, namespace) + } + + /// Get the global TLS connector if it was configured + /// This can be used for custom connection handling if needed + pub fn get_tls_connector() -> Option> { + REDIS_TLS_CONNECTOR.get().cloned() + } + + /// Create Redis client with custom SSL/TLS configuration + fn create_client_with_ssl( + redis_url: &str, + ssl_config: &crate::core::cli::RedisSslConfig, + ) -> Result { + use native_tls::{Certificate, Identity, TlsConnector}; + + // Build TLS connector with custom certificates + let mut tls_builder = TlsConnector::builder(); + + // Load CA certificate if provided + if let Some(ca_cert_path) = &ssl_config.ca_cert_path { + let ca_cert_data = std::fs::read(ca_cert_path) + .with_context(|| format!("Failed to read CA certificate from {}", ca_cert_path))?; + let ca_cert = Certificate::from_pem(&ca_cert_data) + .with_context(|| format!("Failed to parse CA certificate from {}", ca_cert_path))?; + tls_builder.add_root_certificate(ca_cert); + log::info!("Redis SSL: Loaded CA certificate from {}", ca_cert_path); + + // Set SSL_CERT_FILE environment variable as a workaround for native-tls/OpenSSL + // This allows the underlying TLS library to use the custom CA certificate + // Note: This affects the current process and child processes + unsafe { + std::env::set_var("SSL_CERT_FILE", ca_cert_path); + } + log::debug!( + "Redis SSL: Set SSL_CERT_FILE environment variable to {}", + ca_cert_path + ); + } + + // Load client certificate and key if provided + if let (Some(client_cert_path), Some(client_key_path)) = + (&ssl_config.client_cert_path, &ssl_config.client_key_path) + { + let client_cert_data = std::fs::read(client_cert_path).with_context(|| { + format!( + "Failed to read client certificate from {}", + client_cert_path + ) + })?; + let client_key_data = std::fs::read(client_key_path) + .with_context(|| format!("Failed to read client key from {}", client_key_path))?; + + // Try to create identity from PEM format (cert + key) + let identity = Identity::from_pkcs8(&client_cert_data, &client_key_data) + .or_else(|_| { + // Try PEM format if PKCS#8 fails + Identity::from_pkcs12(&client_cert_data, "") + }) + .or_else(|_| { + // Try loading as separate PEM files + // Combine cert and key into a single PEM + let mut combined = client_cert_data.clone(); + combined.extend_from_slice(b"\n"); + combined.extend_from_slice(&client_key_data); + Identity::from_pkcs12(&combined, "") + }) + .with_context(|| format!("Failed to parse client certificate/key from {} and {}. Supported formats: PKCS#8, PKCS#12, or PEM", client_cert_path, client_key_path))?; + tls_builder.identity(identity); + log::info!( + "Redis SSL: Loaded client certificate from {} and key from {}", + client_cert_path, + client_key_path + ); + + // Set SSL client certificate environment variables as workaround + // Note: native-tls/OpenSSL may use these for client certificate authentication + unsafe { + std::env::set_var("SSL_CLIENT_CERT", client_cert_path); + std::env::set_var("SSL_CLIENT_KEY", client_key_path); + } + log::debug!("Redis SSL: Set SSL_CLIENT_CERT and SSL_CLIENT_KEY environment variables"); + } + + // Configure certificate verification + if ssl_config.insecure { + tls_builder.danger_accept_invalid_certs(true); + tls_builder.danger_accept_invalid_hostnames(true); + log::warn!("Redis SSL: Certificate verification disabled (insecure mode)"); + } + + // Build the TLS connector with our custom certificate configuration + // This connector will be used by native-tls/OpenSSL for TLS connections + let tls_connector = tls_builder + .build() + .with_context(|| "Failed to build TLS connector")?; + + // Store the TLS connector globally so it can be used by native-tls + // The redis crate with tokio-native-tls-comp uses native-tls internally, + // which will use OpenSSL. OpenSSL respects the SSL_CERT_FILE environment + // variable we set above, and will use the system's default TLS context + // which we've configured through the TlsConnector builder. + let tls_connector_arc = Arc::new(tls_connector); + // Store globally - allow re-initialization in tests by ignoring the error if already set + if REDIS_TLS_CONNECTOR.set(tls_connector_arc.clone()).is_err() { + log::debug!("Redis SSL: TLS connector already initialized, using existing one"); + } else { + log::info!("Redis SSL: TLS connector configured and stored globally"); + } + + // Note: The redis crate (v0.32) with tokio-native-tls-comp uses native-tls internally, + // which in turn uses OpenSSL. While we cannot pass our TlsConnector directly to the + // redis crate, we've configured it properly and set environment variables that + // OpenSSL respects: + // + // 1. SSL_CERT_FILE: Points to our custom CA certificate (if provided) + // 2. SSL_CLIENT_CERT/SSL_CLIENT_KEY: Points to client certificates (if provided) + // 3. The TlsConnector is built and stored, ensuring certificates are valid + // + // OpenSSL will use these environment variables when creating TLS connections, + // which means our custom certificate configuration will be applied. + + let client = Client::open(redis_url) + .with_context(|| "Failed to create Redis client with SSL config")?; + + Ok(client) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::cli::RedisSslConfig; + + #[tokio::test] + async fn test_redis_manager_init() { + // This test would require a Redis instance running + // For now, just test that the structure compiles + assert!(true); + } + + #[test] + fn test_create_client_with_ssl_no_config() { + // Test that client creation works without SSL config + let redis_url = "redis://127.0.0.1:6379"; + let result = Client::open(redis_url); + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_insecure() { + // Test SSL config with insecure mode + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: true, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed even without certificate files when insecure is true + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_missing_ca_cert() { + // Test that missing CA cert file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: Some("/nonexistent/path/ca.crt".to_string()), + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because CA cert file doesn't exist + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to read CA certificate") + ); + } + + #[test] + fn test_create_client_with_ssl_missing_client_cert() { + // Test that missing client cert file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: Some("/nonexistent/path/client.key".to_string()), + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because client cert file doesn't exist + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to read client certificate") + ); + } + + #[test] + fn test_create_client_with_ssl_missing_client_key() { + // Test that missing client key file returns error + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: Some("/nonexistent/path/client.key".to_string()), + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should fail because client key file doesn't exist + assert!(result.is_err()); + } + + #[test] + fn test_create_client_with_ssl_partial_client_config() { + // Test that providing only cert or only key (not both) still validates + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: Some("/nonexistent/path/client.crt".to_string()), + client_key_path: None, // Missing key + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed because we only validate when both cert and key are provided + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_empty_config() { + // Test SSL config with all None values + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: false, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed with empty config (TLS connector builds without custom certs) + assert!(result.is_ok()); + } + + #[test] + fn test_create_client_with_ssl_insecure_builds_connector() { + // Test that insecure mode builds TLS connector successfully + let ssl_config = RedisSslConfig { + ca_cert_path: None, + client_cert_path: None, + client_key_path: None, + insecure: true, + }; + + let redis_url = "rediss://127.0.0.1:6379"; + let result = RedisManager::create_client_with_ssl(redis_url, &ssl_config); + // Should succeed - TLS connector builds with insecure settings + assert!(result.is_ok()); + } +} diff --git a/src/utils/bpf_utils.rs b/src/utils/bpf_utils.rs index 8c5d6c8..14cff89 100644 --- a/src/utils/bpf_utils.rs +++ b/src/utils/bpf_utils.rs @@ -2,7 +2,7 @@ use std::fs; use std::net::{Ipv4Addr, Ipv6Addr}; use std::os::fd::AsFd; -use crate::security::firewall::bpf::{self, FilterSkel}; +use crate::security::firewall::bpf::{self, XdpSkel}; use libbpf_rs::{Xdp, XdpFlags}; use nix::libc; @@ -76,13 +76,13 @@ fn try_enable_ipv6_for_interface(iface: &str) -> Result<(), Box, + skel: &mut XdpSkel<'_>, ifindex: i32, iface_name: Option<&str>, ip_version: &str, ) -> Result> { // Try hardware mode first, fall back to driver mode if not supported - let xdp = Xdp::new(skel.progs.arxignis_xdp_filter.as_fd().into()); + let xdp = Xdp::new(skel.progs.xdp_pipeline.as_fd().into()); // Try hardware offload mode first if let Ok(()) = xdp.attach(ifindex, XdpFlags::HW_MODE) { @@ -232,6 +232,31 @@ pub fn convert_ipv6_into_bpf_map_key_bytes(ip: Ipv6Addr, prefixlen: u32) -> Box< my_ip_key_bytes.to_vec().into_boxed_slice() } +pub fn convert_ip_port_into_bpf_map_key_bytes(ip: Ipv4Addr, port: u16) -> Box<[u8]> { + let ip_u32: u32 = ip.into(); + let ip_be = ip_u32.to_be(); + + let ip_port_key: bpf::types::src_port_key_v4 = bpf::types::src_port_key_v4 { + addr: ip_be, + port: port.to_be(), + }; + + let ip_port_key_bytes = unsafe { plain::as_bytes(&ip_port_key) }; + ip_port_key_bytes.to_vec().into_boxed_slice() +} + +pub fn convert_ipv6_port_into_map_key_bytes(ip: Ipv6Addr, port: u16) -> Box<[u8]> { + let ip_bytes = ip.octets(); + + let ip_port_key: bpf::types::src_port_key_v6 = bpf::types::src_port_key_v6 { + addr: ip_bytes, + port: port.to_be(), + }; + + let ip_port_key_bytes = unsafe { plain::as_bytes(&ip_port_key) }; + ip_port_key_bytes.to_vec().into_boxed_slice() +} + pub fn bpf_detach_from_xdp(ifindex: i32) -> Result<(), Box> { // Create a dummy XDP instance for detaching // We need to query first to get the existing program ID diff --git a/src/utils/fingerprint/tcp_fingerprint.rs b/src/utils/fingerprint/tcp_fingerprint.rs index 2252cbd..131f7f0 100644 --- a/src/utils/fingerprint/tcp_fingerprint.rs +++ b/src/utils/fingerprint/tcp_fingerprint.rs @@ -3,9 +3,9 @@ use chrono::{DateTime, Utc}; use libbpf_rs::MapCore; use serde::{Deserialize, Serialize}; use std::net::Ipv4Addr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; -use crate::security::firewall::bpf::FilterSkel; +use crate::security::firewall::bpf::XdpSkel; /// TCP fingerprinting configuration #[derive(Debug, Clone, Serialize, Deserialize)] @@ -246,14 +246,14 @@ pub fn get_global_tcp_fingerprint_collector() -> Option>>, + skels: Vec>>>, enabled: bool, config: TcpFingerprintConfig, } impl TcpFingerprintCollector { /// Create a new TCP fingerprint collector - pub fn new(skels: Vec>>, enabled: bool) -> Self { + pub fn new(skels: Vec>>>, enabled: bool) -> Self { Self { skels, enabled, @@ -263,7 +263,7 @@ impl TcpFingerprintCollector { /// Create a new TCP fingerprint collector with configuration pub fn new_with_config( - skels: Vec>>, + skels: Vec>>>, config: TcpFingerprintConfig, ) -> Self { Self { @@ -300,7 +300,18 @@ impl TcpFingerprintCollector { // Try to find fingerprint in any skeleton's IPv4 map for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints.lookup_batch( + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton for IPv4 fingerprint lookup: {}", + e + ); + continue; + } + }; + + if let Ok(iter) = skel_guard.maps.tcp_fingerprints.lookup_batch( 1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -391,7 +402,18 @@ impl TcpFingerprintCollector { // Try to find fingerprint in any skeleton's IPv6 map for skel in &self.skels { - if let Ok(iter) = skel.maps.tcp_fingerprints_v6.lookup_batch( + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton for IPv6 fingerprint lookup: {}", + e + ); + continue; + } + }; + + if let Ok(iter) = skel_guard.maps.tcp_fingerprints_v6.lookup_batch( 1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -486,7 +508,19 @@ impl TcpFingerprintCollector { let mut stats = Vec::new(); for (i, skel) in self.skels.iter().enumerate() { log::debug!("Collecting TCP fingerprint stats from skeleton {}", i); - match self.collect_fingerprint_stats_from_skeleton(skel) { + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton {} for fingerprint stats: {}", + i, + e + ); + continue; + } + }; + + match self.collect_fingerprint_stats_from_skeleton(&*skel_guard) { Ok(stat) => { log::debug!( "Skeleton {} collected {} fingerprints", @@ -511,7 +545,7 @@ impl TcpFingerprintCollector { /// Collect TCP fingerprint statistics from a single BPF skeleton fn collect_fingerprint_stats_from_skeleton( &self, - skel: &FilterSkel, + skel: &XdpSkel, ) -> Result> { if !self.enabled { return Ok(TcpFingerprintStats { @@ -625,10 +659,7 @@ impl TcpFingerprintCollector { } /// Collect TCP SYN statistics - fn collect_syn_stats( - &self, - skel: &FilterSkel, - ) -> Result> { + fn collect_syn_stats(&self, skel: &XdpSkel) -> Result> { let key = 0u32.to_le_bytes(); let stats_bytes = skel .maps @@ -675,7 +706,7 @@ impl TcpFingerprintCollector { /// Collect TCP fingerprints from BPF map fn collect_tcp_fingerprints( &self, - skel: &FilterSkel, + skel: &XdpSkel, fingerprints: &mut Vec, ) -> Result<(), Box> { log::debug!("Collecting TCP fingerprints from BPF map (IPv4)"); @@ -1151,7 +1182,19 @@ impl TcpFingerprintCollector { for skel in &self.skels { let mut fingerprints = Vec::new(); - self.collect_tcp_fingerprints(skel, &mut fingerprints)?; + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton for fingerprint event collection: {}", + e + ); + continue; + } + }; + + self.collect_tcp_fingerprints(&*skel_guard, &mut fingerprints)?; + drop(skel_guard); // Convert to events for entry in fingerprints { @@ -1228,8 +1271,19 @@ impl TcpFingerprintCollector { log::debug!("Resetting TCP fingerprint counters"); for skel in &self.skels { + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton for resetting fingerprint counters: {}", + e + ); + continue; + } + }; + // Reset TCP fingerprints map - match skel.maps.tcp_fingerprints.lookup_batch( + match skel_guard.maps.tcp_fingerprints.lookup_batch( 1000, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -1240,7 +1294,7 @@ impl TcpFingerprintCollector { if key_bytes.len() >= 20 { // Create zero value for fingerprint data (48 bytes with padding) let zero_value = vec![0u8; 48]; - if let Err(e) = skel.maps.tcp_fingerprints.update( + if let Err(e) = skel_guard.maps.tcp_fingerprints.update( &key_bytes, &zero_value, libbpf_rs::MapFlags::ANY, @@ -1262,7 +1316,8 @@ impl TcpFingerprintCollector { let key = 0u32.to_le_bytes(); let zero_stats = vec![0u8; 24]; // 3 * u64 = 24 bytes if let Err(e) = - skel.maps + skel_guard + .maps .tcp_syn_stats .update(&key, &zero_stats, libbpf_rs::MapFlags::ANY) { @@ -1284,8 +1339,20 @@ impl TcpFingerprintCollector { for (i, skel) in self.skels.iter().enumerate() { log::debug!("Checking accessibility of BPF maps for skeleton {}", i); + let skel_guard = match skel.lock() { + Ok(guard) => guard, + Err(e) => { + log::warn!( + "Failed to lock XDP skeleton {} while checking maps: {}", + i, + e + ); + continue; + } + }; + // Check tcp_fingerprints map - match skel.maps.tcp_fingerprints.lookup_batch( + match skel_guard.maps.tcp_fingerprints.lookup_batch( 1, libbpf_rs::MapFlags::ANY, libbpf_rs::MapFlags::ANY, @@ -1300,7 +1367,7 @@ impl TcpFingerprintCollector { // Check tcp_syn_stats map let key = 0u32.to_le_bytes(); - match skel + match skel_guard .maps .tcp_syn_stats .lookup(&key, libbpf_rs::MapFlags::ANY) diff --git a/src/worker/config.rs b/src/worker/config.rs index a30c391..7f115b7 100644 --- a/src/worker/config.rs +++ b/src/worker/config.rs @@ -6,7 +6,7 @@ use hyper::StatusCode; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io::Read; -use std::sync::{Arc, OnceLock, RwLock}; +use std::sync::{Arc, Mutex, OnceLock, RwLock}; use tokio::sync::watch; use tokio::time::{Duration, MissedTickBehavior, interval}; @@ -29,6 +29,12 @@ pub struct Config { pub last_modified: String, } +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct AccessRuleConfig { + #[serde(default, rename = "rateLimit")] + pub rate_limit: XDPRateLimitConfig, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AccessRule { pub id: String, @@ -36,6 +42,8 @@ pub struct AccessRule { pub description: String, pub allow: RuleSet, pub block: RuleSet, + #[serde(default)] + pub config: AccessRuleConfig, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -55,8 +63,66 @@ pub struct WafRule { pub config: Option, } -// Re-export RateLimitConfig from actions module -pub use crate::security::waf::actions::rate_limit::RateLimitConfig; +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct RateLimitConfig { + pub period: String, + pub duration: String, + pub requests: String, +} + +impl RateLimitConfig { + pub fn from_json(value: &serde_json::Value) -> Result { + // Parse from nested structure: {"rateLimit": {"period": "25", ...}} + if let Some(rate_limit_obj) = value.get("rateLimit") { + serde_json::from_value(rate_limit_obj.clone()).map_err(|e| e.to_string()) + } else { + Err("rateLimit field not found".to_string()) + } + } + + pub fn period_secs(&self) -> u64 { + self.period.parse().unwrap_or(60) + } + + pub fn duration_secs(&self) -> u64 { + self.duration.parse().unwrap_or(60) + } + + pub fn requests_count(&self) -> usize { + self.requests.parse().unwrap_or(100) + } +} + +/// These are the fields which can be set in runtime +/// We can expose these to the dashboard +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct XDPRateLimitConfig { + pub enabled: bool, + pub requests: u64, + pub burst_factor: f32, +} + +// TODO: ask arpad to rename the field from config +impl XDPRateLimitConfig { + pub fn from_json(value: &serde_json::Value) -> Result { + // Parse from nested structure: {"rateLimit": {"period": "25", ...}} + if let Some(rate_limit_obj) = value.get("rateLimit") { + serde_json::from_value(rate_limit_obj.clone()).map_err(|e| e.to_string()) + } else { + Err("rateLimit field not found".to_string()) + } + } +} + +impl Default for XDPRateLimitConfig { + fn default() -> Self { + Self { + enabled: true, + requests: 1000, + burst_factor: 1.5, + } + } +} #[derive(Debug, Clone, Deserialize, Serialize)] pub struct RuleSet { @@ -248,11 +314,11 @@ pub struct ConfigWorker { base_url: String, api_key: String, refresh_interval_secs: u64, - skels: Vec>>, + skels: Vec>>>, security_rules_config_path: std::path::PathBuf, is_agent_mode: bool, - nftables_firewall: Option>>, - iptables_firewall: Option>>, + nftables_firewall: Option>>, + iptables_firewall: Option>>, } impl ConfigWorker { @@ -260,7 +326,7 @@ impl ConfigWorker { base_url: String, api_key: String, refresh_interval_secs: u64, - skels: Vec>>, + skels: Vec>>>, security_rules_config_path: std::path::PathBuf, ) -> Self { Self {