diff --git a/benches/module_benchmark.rs b/benches/module_benchmark.rs index 415b5449..1ee8b48c 100644 --- a/benches/module_benchmark.rs +++ b/benches/module_benchmark.rs @@ -39,8 +39,8 @@ use std::collections::HashMap; use std::sync::Arc; use rustible::modules::{ - Diff, ModuleClassification, ModuleContext, ModuleOutput, ModuleParams, ModuleRegistry, - ParallelizationHint, ParamExt, + validate_command_args, Diff, ModuleClassification, ModuleContext, ModuleOutput, ModuleParams, + ModuleRegistry, ParallelizationHint, ParamExt, }; // ============================================================================ @@ -150,6 +150,30 @@ fn generate_context_with_vars(num_vars: usize) -> ModuleContext { // PARAMETER PARSING BENCHMARKS // ============================================================================ +fn bench_validate_command_args(c: &mut Criterion) { + let mut group = c.benchmark_group("validate_command_args"); + + // Fast path: safe alphanumeric + let safe_simple = "nginx -t"; + group.bench_function("safe_simple", |b| { + b.iter(|| validate_command_args(black_box(safe_simple))) + }); + + // Slow path: safe but quoted (this is what we are optimizing) + let safe_quoted = "echo \"hello world\""; + group.bench_function("safe_quoted", |b| { + b.iter(|| validate_command_args(black_box(safe_quoted))) + }); + + // Error path: dangerous + let dangerous = "sh -c 'echo pwned' #"; + group.bench_function("dangerous", |b| { + b.iter(|| validate_command_args(black_box(dangerous))) + }); + + group.finish(); +} + fn bench_parameter_parsing(c: &mut Criterion) { let mut group = c.benchmark_group("parameter_parsing"); @@ -1079,6 +1103,7 @@ fn bench_ansible_comparison_baseline(c: &mut Criterion) { criterion_group!( parameter_benches, + bench_validate_command_args, bench_parameter_parsing, bench_parameter_validation_overhead, ); diff --git a/src/modules/mod.rs b/src/modules/mod.rs index b73cf0bd..f932d92c 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -389,18 +389,15 @@ pub fn validate_command_args(args: &str) -> ModuleResult<()> { )); } - // Fast path: scan for characters that are known to be safe. - // If the string contains only safe characters, we can skip the detailed check. - // This avoids checking 24 patterns for every safe string (O(N) vs O(M*N)). - // - // Safe characters: alphanumeric, space, _, -, ., /, :, +, =, ,, @, % - let is_safe = args.bytes().all(|b| matches!(b, - b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | - b' ' | b'_' | b'-' | b'.' | b'/' | b':' | - b'+' | b'=' | b',' | b'@' | b'%' - )); - - if is_safe { + // Optimized check: scan for any dangerous character in a single pass. + // This replaces the previous fast-path (which only allowed alphanumeric/safe chars) + // and the O(24*N) loop for quoted strings. + const DANGEROUS_CHARS: &[char] = &[ + '$', '`', '&', '|', ';', '<', '>', '\n', '\r', '{', '}', '(', ')', '[', ']', '*', '?', + '!', '\\', '#', + ]; + + if args.find(DANGEROUS_CHARS).is_none() { return Ok(()); }