diff --git a/pingpanda.py b/pingpanda.py index d2d8905..877f868 100644 --- a/pingpanda.py +++ b/pingpanda.py @@ -43,10 +43,15 @@ def load_config(args: argparse.Namespace) -> Dict[str, str]: def main() -> None: + import asyncio args = parse_args() config = load_config(args) monitor = PingPandaApp(config) - monitor.run() + try: + asyncio.run(monitor.run()) + except KeyboardInterrupt: + # Graceful shutdown on Ctrl+C - cleanup is handled in app.run() + pass PingPanda = PingPandaApp diff --git a/pingpanda_core/app.py b/pingpanda_core/app.py index ecbd437..7be5222 100644 --- a/pingpanda_core/app.py +++ b/pingpanda_core/app.py @@ -5,12 +5,15 @@ import logging import os import time +import asyncio from datetime import datetime from importlib import import_module from logging.handlers import RotatingFileHandler -from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, Callable, Dict, List, Optional, Set, Union +import aiohttp +import aiodns + from .checks import CheckDependencies, DNSCheck, PingCheck, SSLCheck, WebsiteCheck from .notifications import NotificationManager, NotificationSettings from .persistence import PersistenceManager, StatsPersistenceSettings @@ -76,12 +79,15 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): self._setup_logging() self._load_config() self._setup_prometheus() - self._thread_pool: Optional[ThreadPoolExecutor] = None self._initialize_components() self.logger.info( "PingPanda initialized at %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) + + # Async resources + self.http_session: Optional[aiohttp.ClientSession] = None + self.dns_resolver: Optional[aiodns.DNSResolver] = None def _setup_logging(self) -> None: log_level = getattr(logging, str(self.config.get("log_level", "INFO")).upper(), logging.INFO) @@ -328,10 +334,9 @@ def _initialize_components(self) -> None: self._ssl_check = SSLCheck(self._check_deps) self._build_check_jobs() - self._setup_thread_pool() def _build_check_jobs(self) -> None: - self._check_jobs: List[Callable[[], None]] = [] + self._check_jobs: List[Callable[[], Any]] = [] if self.enable_dns: self._check_jobs.append(self._dns_check.run) if self.enable_ping: @@ -341,17 +346,14 @@ def _build_check_jobs(self) -> None: if self.enable_ssl_check: self._check_jobs.append(self._ssl_check.run) - def _setup_thread_pool(self) -> None: - worker_count = len(getattr(self, "_check_jobs", [])) - if worker_count == 0: - self._thread_pool = None - return - - self._thread_pool = ThreadPoolExecutor(max_workers=worker_count, thread_name_prefix="pingpanda-check") - self.logger.debug("Thread pool initialized with %s worker(s).", worker_count) - - def send_notification(self, message: str, status: str, check_type: str, target: str) -> None: - self.notifier.notify(message, status=status, check_type=check_type, target=target) + async def send_notification(self, message: str, status: str, check_type: str, target: str) -> None: + await self.notifier.notify( + message, + status=status, + check_type=check_type, + target=target, + session=self.http_session, + ) def _log_filter_notice(self, key: str, message: str, level: int = logging.INFO) -> None: if key not in self._filter_log_tracker: @@ -447,7 +449,7 @@ def output_status_summary(self) -> None: self.logger.info("===============================") - def run(self) -> None: + async def run(self) -> None: self.output_status_summary() if self.enable_advanced_stats: @@ -463,33 +465,37 @@ def run(self) -> None: self.flap_window_seconds, ) + # Initialize async resources + self.http_session = aiohttp.ClientSession() + self.dns_resolver = aiodns.DNSResolver() + try: while True: loop_start = time.time() self._filter_log_tracker.clear() - if self._thread_pool and self._check_jobs: - futures: List[Future[None]] = [self._thread_pool.submit(job) for job in self._check_jobs] - for future in futures: - future.result() - else: - for job in self._check_jobs: - job() + if self._check_jobs: + # Run all check jobs concurrently + tasks = [job() for job in self._check_jobs] + await asyncio.gather(*tasks) self._maybe_output_summary() elapsed = time.time() - loop_start remaining = max(0.0, self.interval - elapsed) - time.sleep(remaining) + await asyncio.sleep(remaining) + except asyncio.CancelledError: + self.logger.info("Shutting down...") except KeyboardInterrupt: self.logger.info("Shutting down gracefully...") finally: - self._cleanup() + await self._cleanup() - def _cleanup(self) -> None: - if self._thread_pool: - self._thread_pool.shutdown(wait=True) - self._thread_pool = None + async def _cleanup(self) -> None: + if self.http_session: + await self.http_session.close() + self.http_session = None + if self.stats_manager: self.stats_manager.save() diff --git a/pingpanda_core/checks.py b/pingpanda_core/checks.py index 99dfbf0..c914c89 100644 --- a/pingpanda_core/checks.py +++ b/pingpanda_core/checks.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import socket import ssl import time @@ -9,8 +10,9 @@ from datetime import datetime from typing import Any, Dict, Optional -import pythonping -import requests +import aiohttp +import aiodns +import aioping from .stats import StatsManager, StatsUpdateResult @@ -29,7 +31,7 @@ def __init__(self, deps: CheckDependencies): def app(self): return self.deps.app - def run(self) -> None: + async def run(self) -> None: app = self.app if not app.enable_dns: return @@ -37,52 +39,66 @@ def run(self) -> None: if not (app.show_only_success or app.show_only_failure): app.logger.info("Starting DNS resolution checks...") + tasks = [] for domain in app.domains: - # Check if we should skip this target due to backoff/circuit breaker - if not app.failure_tracker.should_check(f"dns:{domain}"): - if app.verbose: - app.logger.debug("Skipping DNS check for %s (in backoff/circuit open)", domain) - continue + tasks.append(self._check_domain(domain)) + + await asyncio.gather(*tasks) - start_time = time.perf_counter() - success = False - - for attempt in range(app.retry_count): - try: - socket.gethostbyname(domain) - elapsed = time.perf_counter() - start_time - duration_ms = elapsed * 1000 + async def _check_domain(self, domain: str) -> None: + app = self.app + # Check if we should skip this target due to backoff/circuit breaker + if not app.failure_tracker.should_check(f"dns:{domain}"): + if app.verbose: + app.logger.debug("Skipping DNS check for %s (in backoff/circuit open)", domain) + return - if app._should_log_result(True): - app.logger.info( - "DNS Resolution for %s: PASS (Time: %.2fms)", - domain, - duration_ms, - ) + start_time = time.perf_counter() + success = False - if app.enable_prometheus: - app.dns_status.labels(domain=domain).set(1) - app.dns_response_time.labels(domain=domain).observe(elapsed) + for attempt in range(app.retry_count): + try: + # Use aiodns for async resolution + # Assuming app has a resolver instance or we create one + resolver = getattr(app, "dns_resolver", None) + if not resolver: + # Fallback if not initialized in app (though it should be) + resolver = aiodns.DNSResolver() + + await resolver.query(domain, 'A') + + elapsed = time.perf_counter() - start_time + duration_ms = elapsed * 1000 - app.send_notification( - f"DNS resolution successful in {duration_ms:.2f}ms", - status="ok", - check_type="DNS", - target=domain, + if app._should_log_result(True): + app.logger.info( + "DNS Resolution for %s: PASS (Time: %.2fms)", + domain, + duration_ms, ) - success = True - break - except socket.gaierror as exc: - if app.verbose: - app.logger.debug("DNS Resolution attempt %s for %s failed: %s", attempt + 1, domain, exc) - time.sleep(1) - # Record the result in the failure tracker - app.failure_tracker.record_result(f"dns:{domain}", success) + if app.enable_prometheus: + app.dns_status.labels(domain=domain).set(1) + app.dns_response_time.labels(domain=domain).observe(elapsed) - if success: - continue + await app.send_notification( + f"DNS resolution successful in {duration_ms:.2f}ms", + status="ok", + check_type="DNS", + target=domain, + ) + success = True + break + except (aiodns.error.DNSError, Exception) as exc: + if app.verbose: + app.logger.debug("DNS Resolution attempt %s for %s failed: %s", attempt + 1, domain, exc) + if attempt < app.retry_count - 1: + await asyncio.sleep(1) + # Record the result in the failure tracker + app.failure_tracker.record_result(f"dns:{domain}", success) + + if not success: if app._should_log_result(False): app.logger.error("DNS Resolution for %s: FAIL", domain) @@ -90,7 +106,7 @@ def run(self) -> None: app.dns_status.labels(domain=domain).set(0) app.dns_errors.labels(domain=domain).inc() - app.send_notification( + await app.send_notification( f"Failed to resolve domain after {app.retry_count} attempts", status="error", check_type="DNS", @@ -110,7 +126,7 @@ def app(self): def stats(self) -> Optional[StatsManager]: return self.deps.stats - def run(self) -> None: + async def run(self) -> None: app = self.app if not app.enable_ping: return @@ -118,56 +134,56 @@ def run(self) -> None: if not (app.show_only_success or app.show_only_failure): app.logger.info("Starting ping checks...") + tasks = [] for ip in app.ping_ips: - # Check if we should skip this target due to backoff/circuit breaker - if not app.failure_tracker.should_check(f"ping:{ip}"): - if app.verbose: - app.logger.debug("Skipping ping check for %s (in backoff/circuit open)", ip) - continue - - success = False - start_time = time.perf_counter() + tasks.append(self._check_ip(ip)) + + await asyncio.gather(*tasks) - for attempt in range(app.retry_count): - try: - response_list = pythonping.ping(ip, count=1, timeout=2) - if not response_list.success(): - if app.verbose: - app.logger.debug("Ping attempt %s to %s failed", attempt + 1, ip) - time.sleep(1) - continue + async def _check_ip(self, ip: str) -> None: + app = self.app + # Check if we should skip this target due to backoff/circuit breaker + if not app.failure_tracker.should_check(f"ping:{ip}"): + if app.verbose: + app.logger.debug("Skipping ping check for %s (in backoff/circuit open)", ip) + return - elapsed = time.perf_counter() - start_time - duration_ms = response_list.rtt_avg_ms + success = False - if app._should_log_result(True): - app.logger.info("Ping to %s: PASS (Time: %.2fms)", ip, duration_ms) + for attempt in range(app.retry_count): + try: + # aioping returns delay in seconds + delay = await aioping.ping(ip, timeout=2) + + duration_ms = delay * 1000 - if app.enable_prometheus: - app.ping_status.labels(target=ip).set(1) - app.ping_response_time.labels(target=ip).observe(elapsed) + if app._should_log_result(True): + app.logger.info("Ping to %s: PASS (Time: %.2fms)", ip, duration_ms) - self._update_stats(ip, True) + if app.enable_prometheus: + app.ping_status.labels(target=ip).set(1) + app.ping_response_time.labels(target=ip).observe(delay) - app.send_notification( - f"Ping successful in {duration_ms:.2f}ms", - status="ok", - check_type="Ping", - target=ip, - ) - success = True - break - except Exception as exc: # pylint: disable=broad-except - if app.verbose: - app.logger.debug("Ping attempt %s to %s failed: %s", attempt + 1, ip, exc) - time.sleep(1) + await self._update_stats(ip, True) - # Record the result in the failure tracker - app.failure_tracker.record_result(f"ping:{ip}", success) + await app.send_notification( + f"Ping successful in {duration_ms:.2f}ms", + status="ok", + check_type="Ping", + target=ip, + ) + success = True + break + except Exception as exc: + if app.verbose: + app.logger.debug("Ping attempt %s to %s failed: %s", attempt + 1, ip, exc) + if attempt < app.retry_count - 1: + await asyncio.sleep(1) - if success: - continue + # Record the result in the failure tracker + app.failure_tracker.record_result(f"ping:{ip}", success) + if not success: if app._should_log_result(False): app.logger.error("Ping to %s: FAIL", ip) @@ -175,23 +191,23 @@ def run(self) -> None: app.ping_status.labels(target=ip).set(0) app.ping_errors.labels(target=ip).inc() - self._update_stats(ip, False) + await self._update_stats(ip, False) - app.send_notification( + await app.send_notification( f"Failed to ping host after {app.retry_count} attempts", status="error", check_type="Ping", target=ip, ) - def _update_stats(self, ip: str, success: bool) -> None: + async def _update_stats(self, ip: str, success: bool) -> None: if not self.stats: return result: StatsUpdateResult = self.stats.update_ip(ip, success) if result.flapping_changed and result.is_flapping: - self.app.send_notification( + await self.app.send_notification( f"IP {ip} is flapping (>{self.app.flap_threshold} status changes in {self.app.flap_window_seconds}s)", status="error", check_type="Flapping", @@ -213,7 +229,7 @@ def __init__(self, deps: CheckDependencies): def app(self): return self.deps.app - def run(self) -> None: + async def run(self) -> None: app = self.app if not app.enable_website_check or not app.websites: return @@ -221,62 +237,78 @@ def run(self) -> None: if not (app.show_only_success or app.show_only_failure): app.logger.info("Starting website checks...") + tasks = [] for website in app.websites: if not website: continue + tasks.append(self._check_website(website)) + + await asyncio.gather(*tasks) - # Check if we should skip this target due to backoff/circuit breaker - if not app.failure_tracker.should_check(f"website:{website}"): - if app.verbose: - app.logger.debug("Skipping website check for %s (in backoff/circuit open)", website) - continue + async def _check_website(self, website: str) -> None: + app = self.app + # Check if we should skip this target due to backoff/circuit breaker + if not app.failure_tracker.should_check(f"website:{website}"): + if app.verbose: + app.logger.debug("Skipping website check for %s (in backoff/circuit open)", website) + return - start_time = time.perf_counter() - success = False + start_time = time.perf_counter() + success = False + + # Use existing session if available + session = getattr(app, "http_session", None) + local_session = False + if not session: + session = aiohttp.ClientSession() + local_session = True + + try: try: - response = requests.get(website, timeout=10) - elapsed = time.perf_counter() - start_time - duration_ms = elapsed * 1000 - - if response.status_code in app.success_http_codes: - success = True - if app._should_log_result(True): - app.logger.info( - "Website check for %s: PASS (HTTP Status: %s, Time: %.2fms)", - website, - response.status_code, - duration_ms, + async with session.get(website, timeout=aiohttp.ClientTimeout(total=10)) as response: + elapsed = time.perf_counter() - start_time + duration_ms = elapsed * 1000 + status_code = response.status + + if status_code in app.success_http_codes: + success = True + if app._should_log_result(True): + app.logger.info( + "Website check for %s: PASS (HTTP Status: %s, Time: %.2fms)", + website, + status_code, + duration_ms, + ) + + if app.enable_prometheus: + app.website_status.labels(url=website).set(1) + app.website_response_time.labels(url=website).observe(elapsed) + + await app.send_notification( + f"Website check successful (HTTP {status_code}, {duration_ms:.2f}ms)", + status="ok", + check_type="Website", + target=website, ) - - if app.enable_prometheus: - app.website_status.labels(url=website).set(1) - app.website_response_time.labels(url=website).observe(elapsed) - - app.send_notification( - f"Website check successful (HTTP {response.status_code}, {duration_ms:.2f}ms)", - status="ok", - check_type="Website", - target=website, - ) - else: - if app._should_log_result(False): - app.logger.error( - "Website check for %s: FAIL (HTTP Status: %s)", - website, - response.status_code, + else: + if app._should_log_result(False): + app.logger.error( + "Website check for %s: FAIL (HTTP Status: %s)", + website, + status_code, + ) + + if app.enable_prometheus: + app.website_status.labels(url=website).set(0) + app.website_errors.labels(url=website).inc() + + await app.send_notification( + f"Website returned HTTP {status_code}", + status="error", + check_type="Website", + target=website, ) - - if app.enable_prometheus: - app.website_status.labels(url=website).set(0) - app.website_errors.labels(url=website).inc() - - app.send_notification( - f"Website returned HTTP {response.status_code}", - status="error", - check_type="Website", - target=website, - ) - except requests.RequestException as exc: + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: if app._should_log_result(False): app.logger.error("Website check for %s: FAIL (%s)", website, exc) @@ -284,15 +316,18 @@ def run(self) -> None: app.website_status.labels(url=website).set(0) app.website_errors.labels(url=website).inc() - app.send_notification( + await app.send_notification( f"Failed to reach website: {exc}", status="error", check_type="Website", target=website, ) + finally: + if local_session: + await session.close() - # Record the result in the failure tracker - app.failure_tracker.record_result(f"website:{website}", success) + # Record the result in the failure tracker + app.failure_tracker.record_result(f"website:{website}", success) class SSLCheck: @@ -303,7 +338,7 @@ def __init__(self, deps: CheckDependencies): def app(self): return self.deps.app - def run(self) -> None: + async def run(self) -> None: app = self.app if not app.enable_ssl_check or not app.ssl_check_domains: return @@ -311,53 +346,65 @@ def run(self) -> None: if not (app.show_only_success or app.show_only_failure): app.logger.info("Starting SSL certificate checks...") + tasks = [] for domain in app.ssl_check_domains: - # Check if we should skip this target due to backoff/circuit breaker - if not app.failure_tracker.should_check(f"ssl:{domain}"): - if app.verbose: - app.logger.debug("Skipping SSL check for %s (in backoff/circuit open)", domain) - continue + tasks.append(self._check_ssl(domain)) + + await asyncio.gather(*tasks) - success = False - try: - host, port = self._parse_domain(domain) - days_remaining = self._get_ssl_days_remaining(host, port) - - if days_remaining is None: - continue - - if days_remaining < 0: - message = f"SSL certificate for {domain} has expired" - level = "error" - elif days_remaining <= app.ssl_critical_days: - message = f"SSL certificate for {domain} expires in {days_remaining} days (CRITICAL)" - level = "error" - elif days_remaining <= app.ssl_warn_days: - message = f"SSL certificate for {domain} expires in {days_remaining} days (WARNING)" - level = "warning" - else: - message = f"SSL certificate for {domain} is valid for {days_remaining} more days" - level = "ok" - success = True - - self._handle_result(domain, days_remaining, message, level) - except Exception as exc: # pylint: disable=broad-except - if app._should_log_result(False): - app.logger.error("SSL check for %s failed: %s", domain, exc) + async def _check_ssl(self, domain: str) -> None: + app = self.app + # Check if we should skip this target due to backoff/circuit breaker + if not app.failure_tracker.should_check(f"ssl:{domain}"): + if app.verbose: + app.logger.debug("Skipping SSL check for %s (in backoff/circuit open)", domain) + return - if app.enable_prometheus: - app.ssl_status.labels(domain=domain).set(0) - app.ssl_errors.labels(domain=domain).inc() + success = False + try: + host, port = self._parse_domain(domain) + + # Run the blocking SSL check in a thread executor + loop = asyncio.get_running_loop() + days_remaining = await loop.run_in_executor( + None, self._get_ssl_days_remaining, host, port + ) - app.send_notification( - f"SSL check failed: {exc}", - status="error", - check_type="SSL", - target=domain, - ) + if days_remaining is None: + return + + if days_remaining < 0: + message = f"SSL certificate for {domain} has expired" + level = "error" + elif days_remaining <= app.ssl_critical_days: + message = f"SSL certificate for {domain} expires in {days_remaining} days (CRITICAL)" + level = "error" + elif days_remaining <= app.ssl_warn_days: + message = f"SSL certificate for {domain} expires in {days_remaining} days (WARNING)" + level = "warning" + else: + message = f"SSL certificate for {domain} is valid for {days_remaining} more days" + level = "ok" + success = True + + await self._handle_result(domain, days_remaining, message, level) + except Exception as exc: # pylint: disable=broad-except + if app._should_log_result(False): + app.logger.error("SSL check for %s failed: %s", domain, exc) + + if app.enable_prometheus: + app.ssl_status.labels(domain=domain).set(0) + app.ssl_errors.labels(domain=domain).inc() + + await app.send_notification( + f"SSL check failed: {exc}", + status="error", + check_type="SSL", + target=domain, + ) - # Record the result in the failure tracker - app.failure_tracker.record_result(f"ssl:{domain}", success) + # Record the result in the failure tracker + app.failure_tracker.record_result(f"ssl:{domain}", success) def _parse_domain(self, domain: str) -> tuple[str, int]: if ":" in domain: @@ -366,17 +413,22 @@ def _parse_domain(self, domain: str) -> tuple[str, int]: return domain, 443 def _get_ssl_days_remaining(self, host: str, port: int) -> Optional[int]: + # This is a blocking function, intended to be run in an executor context = ssl.create_default_context() expire_time: Optional[datetime] = None - with socket.create_connection((host, port), timeout=5) as sock: - with context.wrap_socket(sock, server_hostname=host) as wrapped: - cert: Dict[str, Any] = wrapped.getpeercert() or {} + try: + with socket.create_connection((host, port), timeout=5) as sock: + with context.wrap_socket(sock, server_hostname=host) as wrapped: + cert: Dict[str, Any] = wrapped.getpeercert() or {} - not_after = cert.get("notAfter") - if not not_after: - return None + not_after = cert.get("notAfter") + if not not_after: + return None - expire_time = datetime.strptime(str(not_after), "%b %d %H:%M:%S %Y %Z") + expire_time = datetime.strptime(str(not_after), "%b %d %H:%M:%S %Y %Z") + except Exception as e: + self.app.logger.debug("SSL handshake failed for %s:%s: %s", host, port, e) + raise if expire_time is None: return None @@ -387,7 +439,7 @@ def _get_ssl_days_remaining(self, host: str, port: int) -> Optional[int]: return delta.days - def _handle_result(self, domain: str, days_remaining: int, message: str, level: str) -> None: + async def _handle_result(self, domain: str, days_remaining: int, message: str, level: str) -> None: app = self.app if level == "ok": @@ -410,7 +462,7 @@ def _handle_result(self, domain: str, days_remaining: int, message: str, level: if level != "ok": app.ssl_errors.labels(domain=domain).inc() - app.send_notification( + await app.send_notification( message, status=status, check_type="SSL", diff --git a/pingpanda_core/notifications.py b/pingpanda_core/notifications.py index 19387bb..51c518d 100644 --- a/pingpanda_core/notifications.py +++ b/pingpanda_core/notifications.py @@ -4,12 +4,12 @@ import logging import socket -import time +import asyncio from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, Optional -import requests +import aiohttp from .persistence import PersistenceManager @@ -45,7 +45,14 @@ def __init__( self.status_dir = persistence.status_dir self._failure_counts: Dict[str, int] = {} - def notify(self, message: str, status: str, check_type: str, target: str) -> None: + async def notify( + self, + message: str, + status: str, + check_type: str, + target: str, + session: Optional[aiohttp.ClientSession] = None, + ) -> None: if not self._should_notify(check_type, target, status): return @@ -61,19 +68,31 @@ def notify(self, message: str, status: str, check_type: str, target: str) -> Non f"*Message:* {message}" ) - sent = False - + tasks = [] + channels = [] if self.settings.slack_webhook_url: - sent |= self._send_slack(title, formatted_message, status) + tasks.append(self._send_slack(title, formatted_message, status, session)) + channels.append("Slack") if self.settings.teams_webhook_url: - sent |= self._send_teams(title, formatted_message, status) + tasks.append(self._send_teams(title, formatted_message, status, session)) + channels.append("Teams") if self.settings.discord_webhook_url: - sent |= self._send_discord(title, formatted_message, status) + tasks.append(self._send_discord(title, formatted_message, status, session)) + channels.append("Discord") - if not sent: + if not tasks: self.logger.debug("Notification suppressed; no webhook endpoints configured") + return + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for idx, result in enumerate(results): + if isinstance(result, Exception): + self.logger.error( + "Notification to %s failed: %r", channels[idx], result + ) def _status_key(self, check_type: str, target: str) -> str: return f"{check_type}_{target}" @@ -98,7 +117,13 @@ def _should_notify(self, check_type: str, target: str, status: str) -> bool: return False - def _send_slack(self, title: str, message: str, status: str) -> bool: + async def _send_slack( + self, + title: str, + message: str, + status: str, + session: Optional[aiohttp.ClientSession] = None, + ) -> bool: color = "good" if status == "ok" else "danger" payload: Dict[str, Any] = { "text": title, @@ -119,13 +144,20 @@ def _send_slack(self, title: str, message: str, status: str) -> bool: if self.settings.slack_icon_emoji: payload["icon_emoji"] = self.settings.slack_icon_emoji - return self._post_with_retries( + return await self._post_with_retries( self.settings.slack_webhook_url, payload, "Slack", + session=session, ) - def _send_teams(self, title: str, message: str, status: str) -> bool: + async def _send_teams( + self, + title: str, + message: str, + status: str, + session: Optional[aiohttp.ClientSession] = None, + ) -> bool: color = "00FF00" if status == "ok" else "FF0000" payload = { "@type": "MessageCard", @@ -136,13 +168,20 @@ def _send_teams(self, title: str, message: str, status: str) -> bool: "text": message.replace("*", ""), } - return self._post_with_retries( + return await self._post_with_retries( self.settings.teams_webhook_url, payload, "Microsoft Teams", + session=session, ) - def _send_discord(self, title: str, message: str, status: str) -> bool: + async def _send_discord( + self, + title: str, + message: str, + status: str, + session: Optional[aiohttp.ClientSession] = None, + ) -> bool: color = 65280 if status == "ok" else 16711680 payload: Dict[str, Any] = { "embeds": [ @@ -159,35 +198,53 @@ def _send_discord(self, title: str, message: str, status: str) -> bool: if self.settings.discord_avatar_url: payload["avatar_url"] = self.settings.discord_avatar_url - return self._post_with_retries( + return await self._post_with_retries( self.settings.discord_webhook_url, payload, "Discord", + session=session, ) - def _post_with_retries( + async def _post_with_retries( self, url: Optional[str], payload: Dict[str, Any], service: str, headers: Optional[Dict[str, str]] = None, + session: Optional[aiohttp.ClientSession] = None, ) -> bool: if not url: return False + # Use provided session or create a temporary one + if session: + return await self._attempt_post(session, url, payload, service, headers) + + async with aiohttp.ClientSession() as local_session: + return await self._attempt_post(local_session, url, payload, service, headers) + + async def _attempt_post( + self, + session: aiohttp.ClientSession, + url: str, + payload: Dict[str, Any], + service: str, + headers: Optional[Dict[str, str]], + ) -> bool: for attempt in range(1, self.settings.retry_attempts + 1): try: - response = requests.post(url, json=payload, headers=headers, timeout=5) - if response.status_code < 400: - return True - - self.logger.warning( - "%s webhook returned status %s: %s", - service, - response.status_code, - response.text[:500], - ) - except requests.RequestException as exc: + async with session.post(url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as response: + if response.status < 400: + return True + + text = await response.text() + self.logger.warning( + "%s webhook returned status %s: %s", + service, + response.status, + text[:500], + ) + except aiohttp.ClientError as exc: self.logger.warning( "%s notification attempt %s/%s failed: %s", service, @@ -197,7 +254,7 @@ def _post_with_retries( ) if attempt < self.settings.retry_attempts and self.settings.retry_backoff > 0: - time.sleep(self.settings.retry_backoff * attempt) + await asyncio.sleep(self.settings.retry_backoff * attempt) self.logger.error( "Failed to send %s notification after %s attempts", diff --git a/requirements.txt b/requirements.txt index da8d1b8..2612fe8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -pythonping>=1.1.4 -requests>=2.32.3 -pytest>=7.4.0 +aiohttp>=3.9.0 +aiodns>=3.1.0 +aioping>=0.4.0 slack-sdk>=3.34.0 -prometheus-client>=0.17.0 \ No newline at end of file +prometheus-client>=0.17.0 +pytest-asyncio>=0.21.0 \ No newline at end of file diff --git a/tests/test_checks.py b/tests/test_checks.py index df9d5c0..110bb8e 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -1,6 +1,7 @@ import logging -import socket -import time +import pytest +import asyncio +from unittest.mock import MagicMock, AsyncMock import pingpanda_core.checks as checks_module from pingpanda_core.backoff import FailureTracker @@ -32,11 +33,13 @@ def __init__(self): self.flap_threshold = 3 self.flap_window_seconds = 300 self.failure_tracker = FailureTracker(enable_backoff=False) + self.http_session = None + self.dns_resolver = None def _should_log_result(self, is_success): return True - def send_notification(self, message, status, check_type, target): + async def send_notification(self, message, status, check_type, target): self.notifications.append( { "message": message, @@ -47,74 +50,87 @@ def send_notification(self, message, status, check_type, target): ) -def test_dns_check_success(monkeypatch): +@pytest.mark.asyncio +async def test_dns_check_success(monkeypatch): app = DummyApp() app.enable_dns = True app.domains = ["example.com"] - monkeypatch.setattr(time, "sleep", lambda _: None) - monkeypatch.setattr(socket, "gethostbyname", lambda domain: "93.184.216.34") + # Mock aiodns resolver + mock_resolver = MagicMock() + # query is async + mock_resolver.query = AsyncMock(return_value="93.184.216.34") + app.dns_resolver = mock_resolver - DNSCheck(CheckDependencies(app=app, stats=None)).run() + await DNSCheck(CheckDependencies(app=app, stats=None)).run() assert len(app.notifications) == 1 assert app.notifications[0]["status"] == "ok" assert app.notifications[0]["type"] == "DNS" -def test_ping_check_failure(monkeypatch): +@pytest.mark.asyncio +async def test_ping_check_failure(monkeypatch): app = DummyApp() app.enable_ping = True app.ping_ips = ["1.1.1.1"] - monkeypatch.setattr(time, "sleep", lambda _: None) + # Mock aioping.ping to raise exception + monkeypatch.setattr(checks_module.aioping, "ping", AsyncMock(side_effect=Exception("Ping failed"))) + monkeypatch.setattr(asyncio, "sleep", AsyncMock()) - class FailingPing: - def success(self): - return False - - @property - def rtt_avg_ms(self): - return 0.0 - - monkeypatch.setattr(checks_module.pythonping, "ping", lambda *args, **kwargs: FailingPing()) - - PingCheck(CheckDependencies(app=app, stats=None)).run() + await PingCheck(CheckDependencies(app=app, stats=None)).run() assert len(app.notifications) == 1 assert app.notifications[0]["status"] == "error" assert app.notifications[0]["type"] == "Ping" -def test_website_check_non_success(monkeypatch): +@pytest.mark.asyncio +async def test_website_check_non_success(monkeypatch): app = DummyApp() app.enable_website_check = True app.websites = ["https://service"] app.success_http_codes = [200] - class FakeResponse: - def __init__(self, status_code): - self.status_code = status_code - - monkeypatch.setattr(checks_module.requests, "get", lambda *args, **kwargs: FakeResponse(500)) + # Mock aiohttp session + mock_session = MagicMock() + mock_response = AsyncMock() + mock_response.status = 500 + # __aenter__ returns the response + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + + mock_session.get.return_value = mock_response + app.http_session = mock_session - WebsiteCheck(CheckDependencies(app=app, stats=None)).run() + await WebsiteCheck(CheckDependencies(app=app, stats=None)).run() assert len(app.notifications) == 1 assert app.notifications[0]["status"] == "error" assert app.notifications[0]["type"] == "Website" -def test_ssl_check_warning(monkeypatch): +@pytest.mark.asyncio +async def test_ssl_check_warning(monkeypatch): app = DummyApp() app.enable_ssl_check = True app.ssl_check_domains = ["example.com"] app.ssl_warn_days = 30 app.ssl_critical_days = 7 + # Mock _get_ssl_days_remaining which is run in executor + # Since we mock the method on the instance, we can just make it return the value + # But run_in_executor calls it. + + # We can mock loop.run_in_executor + # Or we can mock SSLCheck._get_ssl_days_remaining + + # Since run_in_executor executes the function, if we mock the function it should work. + monkeypatch.setattr(SSLCheck, "_get_ssl_days_remaining", lambda self, host, port: 10) - SSLCheck(CheckDependencies(app=app, stats=None)).run() + await SSLCheck(CheckDependencies(app=app, stats=None)).run() assert len(app.notifications) == 1 assert app.notifications[0]["status"] == "error" # warning mapped to error notifications diff --git a/tests/test_cli.py b/tests/test_cli.py index 4104795..2234572 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,5 @@ import sys - +import asyncio import pingpanda @@ -13,7 +13,7 @@ class DummyMonitor: def __init__(self, config): captured["config"] = config - def run(self): + async def run(self): captured["ran"] = True monkeypatch.setenv("EXISTING", "value") diff --git a/tests/test_notifications.py b/tests/test_notifications.py index 2020958..2700a3a 100644 --- a/tests/test_notifications.py +++ b/tests/test_notifications.py @@ -1,11 +1,13 @@ import logging +import pytest from pathlib import Path from pingpanda_core.notifications import NotificationManager, NotificationSettings from pingpanda_core.persistence import PersistenceManager, StatsPersistenceSettings -def test_notification_threshold_and_recovery(monkeypatch, tmp_path): +@pytest.mark.asyncio +async def test_notification_threshold_and_recovery(monkeypatch, tmp_path): persistence = PersistenceManager( logging.getLogger("pingpanda.tests.notifications"), base_dir=str(tmp_path), @@ -22,7 +24,7 @@ def test_notification_threshold_and_recovery(monkeypatch, tmp_path): sent_statuses = [] - def fake_slack(title, message, status): + async def fake_slack(title, message, status, session=None): sent_statuses.append(status) return True @@ -31,14 +33,14 @@ def fake_slack(title, message, status): status_key = "DNS_example.com" status_path = Path(persistence.status_file_path(status_key)) - manager.notify("down", "error", "DNS", "example.com") + await manager.notify("down", "error", "DNS", "example.com") assert not sent_statuses assert status_path.read_text(encoding="utf-8") == "1" - manager.notify("still down", "error", "DNS", "example.com") + await manager.notify("still down", "error", "DNS", "example.com") assert sent_statuses == ["error"] assert status_path.read_text(encoding="utf-8") == "2" - manager.notify("recovered", "ok", "DNS", "example.com") + await manager.notify("recovered", "ok", "DNS", "example.com") assert sent_statuses[-1] == "ok" assert not status_path.exists()