From e5050941b89acd9fabc9b83f6bfd5a4e803668ea Mon Sep 17 00:00:00 2001
From: "cto-new[bot]" <140088366+cto-new[bot]@users.noreply.github.com>
Date: Fri, 5 Dec 2025 11:46:24 +0000
Subject: [PATCH] refactor(sentiment): modularize pipeline with config,
 fetchers, repository, and async execution

Centralize configuration, introduce a fetcher abstraction with concrete
TwitterFetcher and RedditFetcher, and implement a modular sentiment
analysis service. Add a repository layer with connection pooling and
migrations, and orchestrate with an async pipeline that respects rate
limits and batching. Improve data validation, sanitization, timezone
normalization, and duplicate handling. Enrich logging and typing;
provide CLI entry point and tests/docs scaffolding.
---
 .gitignore                     |  20 ++
 README.md                      |  87 +++++++-
 requirements.txt               |   1 +
 sentiment/__init__.py          |   6 +
 sentiment/analyzer.py          |  57 ++++++
 sentiment/cli.py               | 122 ++++++++++++
 sentiment/config.py            | 129 ++++++++++++
 sentiment/exceptions.py        |  31 +++
 sentiment/exporter.py          |  63 ++++++
 sentiment/fetchers/__init__.py |   1 +
 sentiment/fetchers/base.py     |  90 +++++++++
 sentiment/fetchers/reddit.py   |  58 ++++++
 sentiment/fetchers/twitter.py  |  49 +++++
 sentiment/logging_utils.py     |  32 +++
 sentiment/metrics.py           |  56 ++++++
 sentiment/models.py            |  53 +++++
 sentiment/pipeline.py          | 124 ++++++++++++
 sentiment/rate_limiter.py      |  37 ++++
 sentiment/repository.py        | 352 +++++++++++++++++++++++++++++++++
 sentiment/retry.py             |  42 ++++
 sentiment/sanitizer.py         |  59 ++++++
 tests/__init__.py              |   0
 tests/test_analyzer.py         |  39 ++++
 tests/test_config.py           |  43 ++++
 tests/test_pipeline.py         |  89 +++++++++
 tests/test_repository.py       |  63 ++++++
 tests/test_sanitizer.py        |  33 ++++
 27 files changed, 1735 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 requirements.txt
 create mode 100644 sentiment/__init__.py
 create mode 100644 sentiment/analyzer.py
 create mode 100644 sentiment/cli.py
 create mode 100644 sentiment/config.py
 create mode 100644 sentiment/exceptions.py
 create mode 100644 sentiment/exporter.py
 create mode 100644 sentiment/fetchers/__init__.py
 create mode 100644 sentiment/fetchers/base.py
 create mode 100644 sentiment/fetchers/reddit.py
 create mode 100644 sentiment/fetchers/twitter.py
 create mode 100644 sentiment/logging_utils.py
 create mode 100644 sentiment/metrics.py
 create mode 100644 sentiment/models.py
 create mode 100644 sentiment/pipeline.py
 create mode 100644 sentiment/rate_limiter.py
 create mode 100644 sentiment/repository.py
 create mode 100644 sentiment/retry.py
 create mode 100644 sentiment/sanitizer.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_analyzer.py
 create mode 100644 tests/test_config.py
 create mode 100644 tests/test_pipeline.py
 create mode 100644 tests/test_repository.py
 create mode 100644 tests/test_sanitizer.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a4f0157
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,20 @@
+__pycache__/
+*.py[cod]
+*.log
+*.sqlite3
+*.db
+.env
+.env.*
+.venv/
+.idea/
+.vscode/
+.mypy_cache/
+.pytest_cache/
+coverage.xml
+htmlcov/
+exports/
+metrics.json
+dist/
+build/
+*.egg-info/
+db_utils/__pycache__/
diff --git a/README.md b/README.md
index 7b60ffb..eb918cb 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,86 @@
-# Citadel
\ No newline at end of file
+# Sentiment Analysis Pipeline
+
+## Overview
+
+This repository now ships a production-ready sentiment analysis pipeline that ingests
+posts from social sources (Twitter and Reddit), scores their sentiment, persists the
+results, exports structured data, and records runtime metrics. The pipeline was
+refactored to emphasise reliability, observability, and maintainability.
+
+## Key Features
+
+- **Centralised configuration** via `AppConfig` with strict validation of required
+  environment variables, configurable timeouts, retry counts, and batching controls.
+- **Structured logging** across the stack with DEBUG/INFO/WARNING/ERROR semantics.
+- **Robust API access** featuring rate limiting, exponential backoff retries,
+  response validation, and sanitisation of fetched payloads.
+- **Async orchestration** that fetches from multiple sources concurrently while
+  respecting rate limits and tracking batch progress.
+- **Modular architecture** with dedicated components for fetchers, analysis,
+  persistence, exporting, and metrics.
+- **Database reliability** through a repository that uses connection pooling,
+  schema migrations, context-managed operations, and duplicate detection.
+- **Data hygiene** by sanitising text, validating timestamps, filtering duplicates,
+  and converting everything to a consistent timezone.
+- **Operational insights** via metrics for processed posts, duplicates, failures,
+  and per-sentiment distribution alongside optional JSON/CSV exports.
+
+## Environment Variables
+
+| Variable | Description |
+| --- | --- |
+| `DATABASE_URL` | Database connection string. Supports `sqlite:///path/to.db` or `postgres://user:pass@host:port/db`. |
+| `TWITTER_BEARER_TOKEN` | Required when enabling the Twitter fetcher. |
+| `REDDIT_CLIENT_ID` / `REDDIT_CLIENT_SECRET` | Required when enabling the Reddit fetcher. |
+| `SENTIMENT_SOURCES` | Comma-separated list of fetchers to enable (`twitter,reddit`). Defaults to both. |
+| `REQUEST_TIMEOUT` | HTTP request timeout in seconds (default `10`). |
+| `MAX_RETRIES` | Number of API retry attempts (default `3`). |
+| `RETRY_BACKOFF_FACTOR` | Exponential backoff multiplier (default `2`). |
+| `BATCH_SIZE` | Number of posts processed concurrently in each batch (default `25`). |
+| `RATE_LIMIT_PER_MINUTE` | Maximum API calls per minute (default `60`). |
+| `EXPORT_PATH` | Directory for exported datasets (default `exports`). |
+| `METRICS_EXPORT_PATH` | Output path for metrics JSON (default `metrics.json`). |
+
+## Running the Pipeline
+
+1. Ensure Python 3.11+ is available and install dependencies (e.g. `pip install -r requirements.txt`).
+2. Export the required environment variables shown above.
+3. Launch the CLI:
+
+```bash
+python -m sentiment.cli --query "ai" --limit 50 --export --log-level INFO
+```
+
+The CLI validates configuration, spins up the async pipeline, streams logs, and writes
+metrics/exports on completion.
+
+## Testing
+
+Unit tests cover configuration, sanitisation, the sentiment analyser, repository,
+and pipeline orchestration. Execute them with:
+
+```bash
+python -m unittest discover tests
+```
+
+## Project Structure
+
+```
+sentiment/
+  analyzer.py      # Sentiment scoring service
+  cli.py           # Command-line entry point
+  config.py        # Dataclass-backed configuration
+  exporter.py      # JSON/CSV export helpers
+  fetchers/        # Twitter & Reddit API clients derived from BaseFetcher
+  metrics.py       # Metrics tracking utilities
+  pipeline.py      # Async orchestrator
+  repository.py    # DB access layer with pooling + migrations
+  sanitizer.py     # Text/time validation & sanitisation helpers
+```
+
+## Notes
+
+- Postgres deployments require `psycopg2`. SQLite is supported out-of-the-box for
+  local testing and is used inside the automated unit tests.
+- Metrics are exported after each pipeline run and can be ingested by dashboards or
+  alerting tooling as needed.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..83ec125
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+psycopg2-binary>=2.9
diff --git a/sentiment/__init__.py b/sentiment/__init__.py
new file mode 100644
index 0000000..338beec
--- /dev/null
+++ b/sentiment/__init__.py
@@ -0,0 +1,6 @@
+"""High level sentiment analysis toolkit."""
+
+from .config import AppConfig  # noqa: F401
+from .analyzer import SentimentAnalyzer  # noqa: F401
+from .repository import DataRepository  # noqa: F401
+from .pipeline import SentimentPipeline  # noqa: F401
diff --git a/sentiment/analyzer.py b/sentiment/analyzer.py
new file mode 100644
index 0000000..22956ad
--- /dev/null
+++ b/sentiment/analyzer.py
@@ -0,0 +1,57 @@
+"""Lightweight sentiment analyzer service."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Iterable, Set
+
+from .config import AppConfig
+from .models import SentimentLabel, SentimentResult
+
+
+@dataclass
+class SentimentAnalyzer:
+    """Performs rule-based sentiment analysis on sanitized text."""
+
+    config: AppConfig
+    positive_words: Set[str] = field(default_factory=lambda: {
+        "great",
+        "awesome",
+        "good",
+        "love",
+        "happy",
+        "fantastic",
+        "win",
+        "excellent",
+    })
+    negative_words: Set[str] = field(default_factory=lambda: {
+        "bad",
+        "terrible",
+        "hate",
+        "sad",
+        "awful",
+        "fail",
+        "poor",
+        "angry",
+    })
+
+    def analyze(self, text: str) -> SentimentResult:
+        """Return the sentiment label for the given text."""
+
+        tokens = [token.lower() for token in text.split() if token]
+        score = 0
+        for token in tokens:
+            if token in self.positive_words:
+                score += 1
+            elif token in self.negative_words:
+                score -= 1
+
+        normalized_score = score / max(len(tokens), 1)
+        if normalized_score >= self.config.sentiment_positive_threshold:
+            label = SentimentLabel.POSITIVE
+        elif normalized_score <= self.config.sentiment_negative_threshold:
+            label = SentimentLabel.NEGATIVE
+        else:
+            label = SentimentLabel.NEUTRAL
+
+        return SentimentResult(score=normalized_score, label=label)
diff --git a/sentiment/cli.py b/sentiment/cli.py
new file mode 100644
index 0000000..7811157
--- /dev/null
+++ b/sentiment/cli.py
@@ -0,0 +1,122 @@
+"""Command line interface for the sentiment pipeline."""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import sys
+from typing import List
+
+from .analyzer import SentimentAnalyzer
+from .config import AppConfig
+from .exceptions import ConfigurationError
+from .exporter import DataExporter
+from .fetchers.reddit import RedditFetcher
+from .fetchers.twitter import TwitterFetcher
+from .logging_utils import configure_logging
+from .metrics import MetricsTracker
+from .pipeline import SentimentPipeline
+from .repository import DataRepository
+
+
+async def async_main(args: argparse.Namespace) -> int:
+    """Run the async pipeline entry point."""
+
+    try:
+        config = AppConfig.from_env()
+    except ConfigurationError as exc:
+        logging.getLogger("sentiment").error("Configuration error: %s", exc)
+        return 1
+
+    configure_logging(args.log_level)
+    logger = logging.getLogger("sentiment.cli")
+
+    fetchers: List = []
+    if "twitter" in config.enabled_sources:
+        try:
+            fetchers.append(TwitterFetcher(config))
+        except Exception as exc:
+            logger.error("Twitter fetcher initialization failed: %s", exc)
+    if "reddit" in config.enabled_sources:
+        try:
+            fetchers.append(RedditFetcher(config))
+        except Exception as exc:
+            logger.error("Reddit fetcher initialization failed: %s", exc)
+
+    if not fetchers:
+        logger.error("No fetchers available. Ensure credentials are configured correctly.")
+        return 1
+
+    repository = DataRepository(config)
+    analyzer = SentimentAnalyzer(config)
+    exporter = DataExporter(config.export_path)
+    metrics = MetricsTracker()
+
+    pipeline = SentimentPipeline(
+        config=config,
+        repository=repository,
+        analyzer=analyzer,
+        fetchers=fetchers,
+        exporter=exporter,
+        metrics=metrics,
+    )
+
+    try:
+        result = await pipeline.run(
+            query=args.query,
+            limit=args.limit,
+            sentiment_filter=args.sentiment,
+            sort_by=args.sort_by,
+            descending=args.descending,
+            export=args.export,
+        )
+        logger.info(
+            "Processed %s posts (duplicates: %s)",
+            result.metrics["posts_processed"],
+            result.metrics["duplicates_skipped"],
+        )
+        return 0
+    finally:
+        repository.close()
+
+
+def build_parser() -> argparse.ArgumentParser:
+    """Create an argument parser for the CLI."""
+
+    parser = argparse.ArgumentParser(description="Social sentiment pipeline")
+    parser.add_argument("--query", default="ai", help="Search query")
+    parser.add_argument("--limit", type=int, default=25, help="Max items per fetcher")
+    parser.add_argument(
+        "--sentiment",
+        nargs="*",
+        choices=["positive", "negative", "neutral"],
+        help="Filter output posts by sentiment",
+    )
+    parser.add_argument(
+        "--sort-by",
+        choices=["created_at", "author", "sentiment_score"],
+        default="created_at",
+        help="Sort column",
+    )
+    parser.add_argument("--descending", action="store_true", help="Sort descending")
+    parser.add_argument("--export", action="store_true", help="Export processed posts")
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+        help="Log verbosity",
+    )
+    return parser
+
+
+def main(argv: List[str] | None = None) -> int:
+    """Synchronous CLI entry point."""
+
+    parser = build_parser()
+    args = parser.parse_args(argv)
+    return asyncio.run(async_main(args))
+
+
+if __name__ == "__main__":  # pragma: no cover - manual execution
+    sys.exit(main())
diff --git a/sentiment/config.py b/sentiment/config.py
new file mode 100644
index 0000000..c33caf9
--- /dev/null
+++ b/sentiment/config.py
@@ -0,0 +1,129 @@
+"""Application configuration management."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+from .exceptions import ConfigurationError
+
+
+@dataclass(slots=True)
+class AppConfig:
+    """Runtime configuration for the sentiment pipeline."""
+
+    database_url: str
+    twitter_bearer_token: str | None = None
+    reddit_client_id: str | None = None
+    reddit_client_secret: str | None = None
+    request_timeout: float = 10.0
+    fetch_timeout: float = 15.0
+    max_retries: int = 3
+    retry_backoff_factor: float = 2.0
+    rate_limit_per_minute: int = 60
+    rate_limit_period: float = 60.0
+    batch_size: int = 25
+    concurrency: int = 5
+    export_path: str = "exports"
+    export_format: str = "json"
+    metrics_export_path: str = "metrics.json"
+    min_pool_size: int = 1
+    max_pool_size: int = 5
+    schema_version: int = 1
+    timezone: str = "UTC"
+    duplicate_window_minutes: int = 120
+    sentiment_positive_threshold: float = 0.2
+    sentiment_negative_threshold: float = -0.2
+    sources: List[str] = field(default_factory=lambda: ["twitter", "reddit"])
+
+    def __post_init__(self) -> None:
+        """Normalize configured sources."""
+
+        self.sources = [source.strip().lower() for source in
self.sources if source.strip()] + + @classmethod + def from_env(cls) -> "AppConfig": + """Build the configuration from environment variables.""" + + env = os.environ + sources = env.get("SENTIMENT_SOURCES") + parsed_sources = ( + [src.strip().lower() for src in sources.split(",") if src.strip()] + if sources + else ["twitter", "reddit"] + ) + + config = cls( + database_url=env.get("DATABASE_URL", "").strip(), + twitter_bearer_token=env.get("TWITTER_BEARER_TOKEN"), + reddit_client_id=env.get("REDDIT_CLIENT_ID"), + reddit_client_secret=env.get("REDDIT_CLIENT_SECRET"), + request_timeout=float(env.get("REQUEST_TIMEOUT", 10.0)), + fetch_timeout=float(env.get("FETCH_TIMEOUT", 15.0)), + max_retries=int(env.get("MAX_RETRIES", 3)), + retry_backoff_factor=float(env.get("RETRY_BACKOFF_FACTOR", 2.0)), + rate_limit_per_minute=int(env.get("RATE_LIMIT_PER_MINUTE", 60)), + rate_limit_period=float(env.get("RATE_LIMIT_PERIOD", 60.0)), + batch_size=int(env.get("BATCH_SIZE", 25)), + concurrency=int(env.get("CONCURRENCY", 5)), + export_path=env.get("EXPORT_PATH", "exports"), + export_format=env.get("EXPORT_FORMAT", "json"), + metrics_export_path=env.get("METRICS_EXPORT_PATH", "metrics.json"), + min_pool_size=int(env.get("DB_POOL_MIN", 1)), + max_pool_size=int(env.get("DB_POOL_MAX", 5)), + schema_version=int(env.get("SCHEMA_VERSION", 1)), + timezone=env.get("TIMEZONE", "UTC"), + duplicate_window_minutes=int(env.get("DUPLICATE_WINDOW_MINUTES", 120)), + sentiment_positive_threshold=float(env.get("SENTIMENT_POS_THRESHOLD", 0.2)), + sentiment_negative_threshold=float(env.get("SENTIMENT_NEG_THRESHOLD", -0.2)), + sources=parsed_sources, + ) + config.validate() + return config + + @property + def enabled_sources(self) -> List[str]: + """Return sources that can be instantiated with the available secrets.""" + + enabled: list[str] = [] + for source in self.sources: + if source == "twitter" and self.twitter_bearer_token: + enabled.append(source) + elif ( + source == "reddit" + and self.reddit_client_id + and self.reddit_client_secret + ): + enabled.append(source) + return enabled + + def validate(self) -> None: + """Validate that the configuration satisfies all prerequisites.""" + + missing: list[str] = [] + if not self.database_url: + missing.append("DATABASE_URL") + + normalized_sources = {source.strip().lower() for source in self.sources} + if "twitter" in normalized_sources and not self.twitter_bearer_token: + missing.append("TWITTER_BEARER_TOKEN") + if "reddit" in normalized_sources: + if not self.reddit_client_id: + missing.append("REDDIT_CLIENT_ID") + if not self.reddit_client_secret: + missing.append("REDDIT_CLIENT_SECRET") + + if missing: + raise ConfigurationError( + "Missing required environment variables: " + ", ".join(sorted(set(missing))) + ) + + if self.max_pool_size < self.min_pool_size: + raise ConfigurationError("DB_POOL_MAX must be >= DB_POOL_MIN") + + if self.batch_size <= 0: + raise ConfigurationError("BATCH_SIZE must be a positive integer") + + if not normalized_sources: + raise ConfigurationError("At least one data source must be specified") diff --git a/sentiment/exceptions.py b/sentiment/exceptions.py new file mode 100644 index 0000000..5b1b46e --- /dev/null +++ b/sentiment/exceptions.py @@ -0,0 +1,31 @@ +"""Custom exceptions for the sentiment analysis pipeline.""" + +from __future__ import annotations + + +class SentimentError(Exception): + """Base exception for sentiment related errors.""" + + +class ConfigurationError(SentimentError): + """Raised when the application configuration is 
invalid.""" + + +class APIFetchError(SentimentError): + """Raised when an upstream API call fails.""" + + +class DatabaseError(SentimentError): + """Raised when database operations fail.""" + + +class DuplicatePostError(DatabaseError): + """Raised when attempting to insert an already persisted post.""" + + +class RateLimitError(SentimentError): + """Raised when the rate limiter cannot provide capacity in time.""" + + +class ValidationError(SentimentError): + """Raised when data validation fails.""" diff --git a/sentiment/exporter.py b/sentiment/exporter.py new file mode 100644 index 0000000..3dbac67 --- /dev/null +++ b/sentiment/exporter.py @@ -0,0 +1,63 @@ +"""Data export helpers.""" + +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Iterable, List + +from .models import SocialPost + + +class DataExporter: + """Exports processed posts to disk.""" + + def __init__(self, output_dir: str | Path) -> None: + self._output_dir = Path(output_dir) + self._output_dir.mkdir(parents=True, exist_ok=True) + + def export(self, posts: Iterable[SocialPost], *, filename: str, fmt: str = "json") -> Path: + """Export posts to the requested format.""" + + if fmt not in {"json", "csv"}: + raise ValueError("Unsupported export format") + + path = self._output_dir / f"{filename}.{fmt}" + if fmt == "json": + self._export_json(posts, path) + else: + self._export_csv(posts, path) + return path + + def _export_json(self, posts: Iterable[SocialPost], path: Path) -> None: + serialized = [self._serialize_post(post) for post in posts] + with path.open("w", encoding="utf-8") as handle: + json.dump(serialized, handle, indent=2) + + def _export_csv(self, posts: Iterable[SocialPost], path: Path) -> None: + serialized = [self._serialize_post(post) for post in posts] + if not serialized: + path.touch() + return + + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(serialized[0].keys())) + writer.writeheader() + writer.writerows(serialized) + + @staticmethod + def _serialize_post(post: SocialPost) -> dict: + result = { + "external_id": post.external_id, + "author": post.author, + "source": post.source.value, + "created_at": post.created_at.isoformat(), + "text": post.text, + "url": post.url, + "metadata": post.metadata, + } + if post.sentiment: + result["sentiment_score"] = post.sentiment.score + result["sentiment_label"] = post.sentiment.label.value + return result diff --git a/sentiment/fetchers/__init__.py b/sentiment/fetchers/__init__.py new file mode 100644 index 0000000..0392ce0 --- /dev/null +++ b/sentiment/fetchers/__init__.py @@ -0,0 +1 @@ +"""Fetcher implementations.""" diff --git a/sentiment/fetchers/base.py b/sentiment/fetchers/base.py new file mode 100644 index 0000000..cffbe0c --- /dev/null +++ b/sentiment/fetchers/base.py @@ -0,0 +1,90 @@ +"""Base fetcher definition.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import urllib.error +import urllib.parse +import urllib.request +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Mapping, Optional + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost +from ..rate_limiter import AsyncRateLimiter +from ..retry import execute_with_retry + + +class BaseFetcher(ABC): + """Common functionality for remote data fetchers.""" + + def __init__(self, config: AppConfig, logger: logging.Logger | None = None) -> 
None: + self.config = config + self.logger = logger or logging.getLogger(self.__class__.__name__) + self.rate_limiter = AsyncRateLimiter( + config.rate_limit_per_minute, + config.rate_limit_period, + ) + + @abstractmethod + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + """Fetch posts for the given query.""" + + async def _request(self, url: str, params: Mapping[str, Any], headers: Mapping[str, str]) -> Dict[str, Any]: + """Execute an HTTP GET request with retry and rate limiting.""" + + await self.rate_limiter.acquire() + + async def _call() -> Dict[str, Any]: + return await asyncio.to_thread(self._perform_request, url, params, headers) + + return await execute_with_retry( + _call, + max_attempts=self.config.max_retries, + backoff_factor=self.config.retry_backoff_factor, + retry_exceptions=(APIFetchError, urllib.error.URLError, urllib.error.HTTPError, ConnectionError), + logger=self.logger, + base_delay=1.0, + ) + + def _perform_request( + self, + url: str, + params: Mapping[str, Any], + headers: Mapping[str, str], + ) -> Dict[str, Any]: + """Perform an HTTP GET request synchronously.""" + + try: + query = urllib.parse.urlencode({k: v for k, v in params.items() if v is not None}) + delimiter = "?" if "?" not in url else "&" + target = f"{url}{delimiter}{query}" if query else url + request = urllib.request.Request(target, headers=headers) + with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: + payload = response.read().decode("utf-8") + except urllib.error.HTTPError as exc: # pragma: no cover - network errors + raise APIFetchError(f"HTTP error {exc.code}: {exc.reason}") from exc + except urllib.error.URLError as exc: # pragma: no cover - network errors + raise APIFetchError(f"Network error: {exc.reason}") from exc + + try: + return json.loads(payload) + except json.JSONDecodeError as exc: + raise APIFetchError("API response was not valid JSON") from exc + + @staticmethod + def _parse_timestamp(value: str | None) -> Optional[datetime]: + if not value: + return None + normalized = value.replace("Z", "+00:00") + return datetime.fromisoformat(normalized) + + @staticmethod + def _validate_keys(payload: Mapping[str, Any], required_keys: List[str]) -> None: + missing = [key for key in required_keys if key not in payload] + if missing: + raise APIFetchError(f"API response missing keys: {', '.join(missing)}") diff --git a/sentiment/fetchers/reddit.py b/sentiment/fetchers/reddit.py new file mode 100644 index 0000000..5539c61 --- /dev/null +++ b/sentiment/fetchers/reddit.py @@ -0,0 +1,58 @@ +"""Reddit fetcher implementation.""" + +from __future__ import annotations + +import base64 +from datetime import datetime, timezone +from typing import Any, Dict, List + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost, SocialSource +from ..sanitizer import sanitize_text +from .base import BaseFetcher + + +class RedditFetcher(BaseFetcher): + """Fetch posts from Reddit's public search API.""" + + API_URL = "https://www.reddit.com/search.json" + + def __init__(self, config: AppConfig) -> None: + super().__init__(config) + if not config.reddit_client_id or not config.reddit_client_secret: + raise APIFetchError("Reddit credentials are missing") + + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + params = { + "q": query, + "limit": min(max(limit, 10), 100), + "sort": "new", + "restrict_sr": False, + } + headers = { + "User-Agent": "sentiment-analyzer/1.0", + "Authorization": f"Basic 
{self._encode_credentials()}", + } + response = await self._request(self.API_URL, params, headers) + self._validate_keys(response, ["data"]) + children = response["data"].get("children", []) + return [self._build_post(child.get("data", {})) for child in children] + + def _encode_credentials(self) -> str: + token = f"{self.config.reddit_client_id}:{self.config.reddit_client_secret}" + encoded = base64.b64encode(token.encode("utf-8")).decode("ascii") + return encoded + + def _build_post(self, payload: Dict[str, Any]) -> SocialPost: + created_utc = payload.get("created_utc") + created_at = datetime.fromtimestamp(created_utc, tz=timezone.utc) if created_utc else datetime.now(tz=timezone.utc) + return SocialPost( + external_id=str(payload.get("id")), + text=sanitize_text(payload.get("selftext") or payload.get("title", "")), + author=str(payload.get("author", "unknown")), + source=SocialSource.REDDIT, + created_at=created_at, + url=f"https://www.reddit.com{payload.get('permalink', '')}", + metadata={"raw": payload}, + ) diff --git a/sentiment/fetchers/twitter.py b/sentiment/fetchers/twitter.py new file mode 100644 index 0000000..ca1c443 --- /dev/null +++ b/sentiment/fetchers/twitter.py @@ -0,0 +1,49 @@ +"""Twitter API fetcher.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost, SocialSource +from ..sanitizer import sanitize_text +from .base import BaseFetcher + + +class TwitterFetcher(BaseFetcher): + """Fetches tweets using the Twitter v2 API.""" + + API_URL = "https://api.twitter.com/2/tweets/search/recent" + + def __init__(self, config: AppConfig) -> None: + super().__init__(config) + if not config.twitter_bearer_token: + raise APIFetchError("Twitter bearer token is missing") + + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + params = { + "query": query, + "max_results": min(max(limit, 10), 100), + "tweet.fields": "created_at,author_id,lang" + } + headers = { + "Authorization": f"Bearer {self.config.twitter_bearer_token}", + } + response = await self._request(self.API_URL, params, headers) + self._validate_keys(response, ["data"]) + return [self._build_post(tweet) for tweet in response.get("data", [])] + + def _build_post(self, tweet: Dict[str, Any]) -> SocialPost: + text = sanitize_text(tweet.get("text", "")) + created_at = self._parse_timestamp(tweet.get("created_at")) or datetime.utcnow() + return SocialPost( + external_id=str(tweet.get("id")), + text=text, + author=str(tweet.get("author_id", "unknown")), + source=SocialSource.TWITTER, + created_at=created_at, + url=f"https://twitter.com/i/web/status/{tweet.get('id')}", + metadata={"raw": tweet}, + ) diff --git a/sentiment/logging_utils.py b/sentiment/logging_utils.py new file mode 100644 index 0000000..cad5ed6 --- /dev/null +++ b/sentiment/logging_utils.py @@ -0,0 +1,32 @@ +"""Logging helpers for the sentiment analysis project.""" + +from __future__ import annotations + +import logging +from typing import Optional + + +def configure_logging(level: str = "INFO") -> logging.Logger: + """Configure and return the application logger. + + Args: + level: Desired log level name. + + Returns: + Configured root logger. 
+ """ + + numeric_level = getattr(logging, level.upper(), logging.INFO) + logger = logging.getLogger("sentiment") + + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.setLevel(numeric_level) + logger.propagate = False + return logger diff --git a/sentiment/metrics.py b/sentiment/metrics.py new file mode 100644 index 0000000..6747ede --- /dev/null +++ b/sentiment/metrics.py @@ -0,0 +1,56 @@ +"""Metrics tracking for the sentiment pipeline.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict + + +@dataclass +class MetricsTracker: + """Track ingestion metrics.""" + + posts_processed: int = 0 + duplicates_skipped: int = 0 + failures: int = 0 + sentiment_distribution: Dict[str, int] = field(default_factory=lambda: { + "positive": 0, + "negative": 0, + "neutral": 0, + }) + + def record_sentiment(self, label: str) -> None: + """Update counters for sentiment labels.""" + + self.posts_processed += 1 + self.sentiment_distribution[label] = self.sentiment_distribution.get(label, 0) + 1 + + def record_duplicate(self) -> None: + """Mark that a duplicate post has been skipped.""" + + self.duplicates_skipped += 1 + + def record_failure(self) -> None: + """Increment failure counter.""" + + self.failures += 1 + + def export(self, path: str | Path) -> None: + """Persist metrics to disk.""" + + destination = Path(path) + destination.parent.mkdir(parents=True, exist_ok=True) + with destination.open("w", encoding="utf-8") as handle: + json.dump(self.as_dict(), handle, indent=2) + + def as_dict(self) -> Dict[str, int | Dict[str, int]]: + """Serialize metrics to a dictionary.""" + + return { + "posts_processed": self.posts_processed, + "duplicates_skipped": self.duplicates_skipped, + "failures": self.failures, + "sentiment_distribution": self.sentiment_distribution, + } diff --git a/sentiment/models.py b/sentiment/models.py new file mode 100644 index 0000000..b268441 --- /dev/null +++ b/sentiment/models.py @@ -0,0 +1,53 @@ +"""Data models used throughout the sentiment pipeline.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + + +class SocialSource(str, Enum): + """Supported social data sources.""" + + TWITTER = "twitter" + REDDIT = "reddit" + + +class SentimentLabel(str, Enum): + """Possible sentiment labels.""" + + POSITIVE = "positive" + NEGATIVE = "negative" + NEUTRAL = "neutral" + + +@dataclass(slots=True) +class SentimentResult: + """Represents the sentiment of a chunk of text.""" + + score: float + label: SentimentLabel + + +@dataclass(slots=True) +class SocialPost: + """Normalized social post structure.""" + + external_id: str + text: str + author: str + source: SocialSource + created_at: datetime + url: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + sentiment: Optional[SentimentResult] = None + + +@dataclass(slots=True) +class PipelineResult: + """Holds aggregate pipeline output.""" + + posts: List[SocialPost] + metrics: Dict[str, Any] diff --git a/sentiment/pipeline.py b/sentiment/pipeline.py new file mode 100644 index 0000000..1adcfc1 --- /dev/null +++ b/sentiment/pipeline.py @@ -0,0 +1,124 @@ +"""Sentiment pipeline orchestration.""" + +from __future__ import 
annotations + +import asyncio +import logging +from typing import Iterable, List, Sequence + +from .analyzer import SentimentAnalyzer +from .config import AppConfig +from .exceptions import DatabaseError, ValidationError +from .exporter import DataExporter +from .metrics import MetricsTracker +from .models import PipelineResult, SocialPost +from .repository import DataRepository +from .sanitizer import chunk_items, validate_post +from .fetchers.base import BaseFetcher + + +class SentimentPipeline: + """Coordinates fetching, analyzing, and persisting social posts.""" + + def __init__( + self, + *, + config: AppConfig, + repository: DataRepository, + analyzer: SentimentAnalyzer, + fetchers: Sequence[BaseFetcher], + exporter: DataExporter, + metrics: MetricsTracker, + logger: logging.Logger | None = None, + ) -> None: + self.config = config + self.repository = repository + self.analyzer = analyzer + self.fetchers = list(fetchers) + self.exporter = exporter + self.metrics = metrics + self.logger = logger or logging.getLogger(self.__class__.__name__) + self._semaphore = asyncio.Semaphore(config.concurrency) + + async def run( + self, + *, + query: str, + limit: int, + sentiment_filter: Sequence[str] | None = None, + sort_by: str = "created_at", + descending: bool = False, + export: bool = False, + ) -> PipelineResult: + """Execute the sentiment ingestion pipeline.""" + + fetched_posts = await self._gather_posts(query=query, limit=limit) + await self._process_posts(fetched_posts) + + posts = self.repository.fetch_posts( + sentiment_filter=sentiment_filter, + sort_by=sort_by, + descending=descending, + ) + if export: + filename = "sentiment_results" + self.exporter.export(posts, filename=filename, fmt=self.config.export_format) + + self.metrics.export(self.config.metrics_export_path) + return PipelineResult(posts=posts, metrics=self.metrics.as_dict()) + + async def _gather_posts(self, *, query: str, limit: int) -> List[SocialPost]: + tasks = [self._fetch_with_guard(fetcher, query, limit) for fetcher in self.fetchers] + results = await asyncio.gather(*tasks, return_exceptions=True) + posts: list[SocialPost] = [] + for result in results: + if isinstance(result, Exception): + self.metrics.record_failure() + self.logger.error("Fetcher failed", exc_info=result) + continue + posts.extend(result) + return posts + + async def _fetch_with_guard( + self, + fetcher: BaseFetcher, + query: str, + limit: int, + ) -> List[SocialPost]: + async with self._semaphore: + return await asyncio.wait_for( + fetcher.fetch(query, limit), + timeout=self.config.fetch_timeout, + ) + + async def _process_posts(self, posts: Sequence[SocialPost]) -> None: + total = len(posts) + if total == 0: + self.logger.info("No posts returned by fetchers") + return + + processed = 0 + for chunk in chunk_items(posts, self.config.batch_size): + await asyncio.gather(*(self._process_single(post) for post in chunk)) + processed += len(chunk) + self.logger.info("Processed %s/%s posts", processed, total) + + async def _process_single(self, post: SocialPost) -> None: + try: + validated = validate_post(post, self.config.timezone) + sentiment = self.analyzer.analyze(validated.text) + validated.sentiment = sentiment + inserted = await asyncio.to_thread(self.repository.insert_post, validated) + if inserted: + self.metrics.record_sentiment(sentiment.label.value) + else: + self.metrics.record_duplicate() + except ValidationError as exc: + self.metrics.record_failure() + self.logger.warning("Validation failed: %s", exc) + except DatabaseError as exc: + 
self.metrics.record_failure() + self.logger.error("Database error: %s", exc) + except Exception as exc: # pragma: no cover - safeguard + self.metrics.record_failure() + self.logger.exception("Unexpected processing error: %s", exc) diff --git a/sentiment/rate_limiter.py b/sentiment/rate_limiter.py new file mode 100644 index 0000000..04da9b6 --- /dev/null +++ b/sentiment/rate_limiter.py @@ -0,0 +1,37 @@ +"""Simple asynchronous rate limiter implementation.""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque + +from .exceptions import RateLimitError + + +class AsyncRateLimiter: + """Token bucket style async rate limiter.""" + + def __init__(self, max_calls: int, period: float) -> None: + self._max_calls = max_calls + self._period = period + self._calls = deque[float]() + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + """Acquire permission to perform an operation respecting the rate limit.""" + + while True: + async with self._lock: + now = time.monotonic() + while self._calls and now - self._calls[0] > self._period: + self._calls.popleft() + + if len(self._calls) < self._max_calls: + self._calls.append(now) + return + + sleep_for = self._period - (now - self._calls[0]) + if sleep_for > self._period * 2: + raise RateLimitError("Rate limiter queue overflowed") + await asyncio.sleep(max(0.0, sleep_for)) diff --git a/sentiment/repository.py b/sentiment/repository.py new file mode 100644 index 0000000..eb3bfce --- /dev/null +++ b/sentiment/repository.py @@ -0,0 +1,352 @@ +"""Database repository with connection pooling and migrations.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import threading +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from queue import Queue +from typing import Any, List, Sequence +from urllib.parse import urlparse + +try: # pragma: no cover - optional dependency at runtime + from psycopg2.pool import SimpleConnectionPool # type: ignore +except Exception: # pragma: no cover - handle environments without psycopg2 + SimpleConnectionPool = None # type: ignore + +from .config import AppConfig +from .exceptions import DatabaseError +from .models import SentimentLabel, SentimentResult, SocialPost, SocialSource + + +class SQLiteConnectionPool: + """Simple connection pool for sqlite3 connections.""" + + def __init__(self, database: str, max_size: int) -> None: + self._database = database + self._max_size = max_size + self._lock = threading.Lock() + self._created = 0 + self._queue: "Queue[sqlite3.Connection]" = Queue() + + def getconn(self) -> sqlite3.Connection: + with self._lock: + if not self._queue.empty(): + return self._queue.get() + if self._created < self._max_size: + conn = sqlite3.connect( + self._database, + detect_types=sqlite3.PARSE_DECLTYPES, + check_same_thread=False, + ) + conn.row_factory = sqlite3.Row + self._created += 1 + return conn + return self._queue.get() + + def putconn(self, connection: sqlite3.Connection) -> None: + self._queue.put(connection) + + def closeall(self) -> None: + while not self._queue.empty(): + conn = self._queue.get() + conn.close() + + +@dataclass +class SchemaManager: + """Apply schema migrations sequentially.""" + + repository: "DataRepository" + target_version: int + + def apply_migrations(self) -> None: + current_version = self._get_current_version() + while current_version < 
self.target_version: + next_version = current_version + 1 + migration = getattr(self, f"_migration_v{next_version}", None) + if migration is None: + raise DatabaseError(f"Missing migration for version {next_version}") + migration() + self._record_version(next_version) + current_version = next_version + + def _get_current_version(self) -> int: + self.repository._ensure_migrations_table() + rows = self.repository._execute("SELECT MAX(version) as version FROM schema_migrations", fetch=True) + if not rows or rows[0]["version"] is None: + return 0 + return int(rows[0]["version"]) + + def _record_version(self, version: int) -> None: + placeholder = self.repository.placeholder + query = f"INSERT INTO schema_migrations (version, applied_at) VALUES ({placeholder}, {placeholder})" + now = datetime.now(timezone.utc).isoformat() + self.repository._execute(query, (version, now)) + + def _migration_v1(self) -> None: + driver = self.repository.driver + timestamp_type = "TIMESTAMPTZ" if driver == "postgres" else "TEXT" + metadata_type = "JSONB" if driver == "postgres" else "TEXT" + id_column = ( + "SERIAL PRIMARY KEY" + if driver == "postgres" + else "INTEGER PRIMARY KEY AUTOINCREMENT" + ) + + schema_statements = [ + f""" + CREATE TABLE IF NOT EXISTS social_posts ( + id {id_column}, + external_id TEXT UNIQUE NOT NULL, + text TEXT NOT NULL, + author TEXT NOT NULL, + source TEXT NOT NULL, + created_at {timestamp_type} NOT NULL, + url TEXT, + sentiment_label TEXT, + sentiment_score REAL, + metadata {metadata_type}, + inserted_at {timestamp_type} DEFAULT CURRENT_TIMESTAMP + ); + """, + ] + + for statement in schema_statements: + self.repository._execute(statement) + + +class DataRepository: + """Provides database access with pooling and schema management.""" + + def __init__(self, config: AppConfig, logger: logging.Logger | None = None) -> None: + self.config = config + self.logger = logger or logging.getLogger(self.__class__.__name__) + self.driver: str + self.placeholder: str + self._pool = self._create_pool() + self.schema_manager = SchemaManager(self, config.schema_version) + self.schema_manager.apply_migrations() + + def _create_pool(self) -> Any: + parsed = urlparse(self.config.database_url) + scheme = (parsed.scheme or "sqlite").lower() + if scheme.startswith("postgres"): + if SimpleConnectionPool is None: + raise DatabaseError("psycopg2 is required for PostgreSQL connections") + self.driver = "postgres" + self.placeholder = "%s" + dbname = (parsed.path or "/")[1:] + pool = SimpleConnectionPool( + self.config.min_pool_size, + self.config.max_pool_size, + user=parsed.username, + password=parsed.password, + host=parsed.hostname, + port=parsed.port or 5432, + database=dbname, + connect_timeout=int(self.config.request_timeout), + ) + return pool + + if scheme.startswith("sqlite"): + self.driver = "sqlite" + self.placeholder = "?" 
+ path = parsed.path or parsed.netloc + if path.startswith("/"): + db_path = path + else: + db_path = f"./{path}" if path else "./sentiment.sqlite3" + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + return SQLiteConnectionPool(db_path, self.config.max_pool_size) + + raise DatabaseError(f"Unsupported database scheme: {scheme}") + + @contextmanager + def get_connection(self) -> Iterator[Any]: + connection = self._acquire_connection() + try: + yield connection + connection.commit() + except sqlite3.Error as exc: # pragma: no cover - exercised indirectly + connection.rollback() + self.logger.error("SQLite error: %s", exc) + raise DatabaseError(str(exc)) from exc + except Exception as exc: # pragma: no cover - exercised indirectly + connection.rollback() + self.logger.error("Database error: %s", exc) + raise + finally: + self._release_connection(connection) + + def _acquire_connection(self) -> Any: + if self.driver == "postgres": + return self._pool.getconn() + return self._pool.getconn() + + def _release_connection(self, connection: Any) -> None: + if self.driver == "postgres": + self._pool.putconn(connection) + else: + self._pool.putconn(connection) + + def _ensure_migrations_table(self) -> None: + timestamp_type = "TIMESTAMPTZ" if self.driver == "postgres" else "TEXT" + statement = f""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at {timestamp_type} NOT NULL + ); + """ + self._execute(statement) + + def insert_post(self, post: SocialPost) -> bool: + """Insert a post if it does not already exist.""" + + exists_query = ( + "SELECT 1 FROM social_posts WHERE external_id = " + f"{self.placeholder} LIMIT 1" + ) + params = (post.external_id,) + rows = self._execute(exists_query, params, fetch=True) + if rows: + return False + + insert_query = ( + "INSERT INTO social_posts (external_id, text, author, source, created_at, url, " + "sentiment_label, sentiment_score, metadata) " + f"VALUES ({', '.join([self.placeholder] * 9)})" + ) + try: + metadata_value = json.dumps(post.metadata) + except TypeError: + metadata_value = json.dumps({"raw": str(post.metadata)}) + sentiment_label = ( + post.sentiment.label.value if post.sentiment else None + ) + sentiment_score = ( + float(post.sentiment.score) if post.sentiment else None + ) + created_at_value = ( + post.created_at if self.driver == "postgres" else post.created_at.isoformat() + ) + insert_params = ( + post.external_id, + post.text, + post.author, + post.source.value, + created_at_value, + post.url, + sentiment_label, + sentiment_score, + metadata_value, + ) + self._execute(insert_query, insert_params) + return True + + def fetch_posts( + self, + *, + sentiment_filter: Sequence[str] | None = None, + limit: int | None = None, + sort_by: str = "created_at", + descending: bool = False, + ) -> List[SocialPost]: + """Fetch posts optionally filtered and sorted.""" + + allowed_sort = {"created_at", "author", "sentiment_score"} + if sort_by not in allowed_sort: + raise ValueError("Unsupported sort column") + + clauses: list[str] = [] + params: list[Any] = [] + if sentiment_filter: + placeholders = ", ".join([self.placeholder] * len(sentiment_filter)) + clauses.append(f"sentiment_label IN ({placeholders})") + params.extend(sentiment_filter) + + where_sql = f" WHERE {' AND '.join(clauses)}" if clauses else "" + order_sql = f" ORDER BY {sort_by} {'DESC' if descending else 'ASC'}" + limit_sql = f" LIMIT {limit}" if limit else "" + + query = ( + "SELECT external_id, text, author, source, created_at, url, 
sentiment_label, sentiment_score, metadata " + "FROM social_posts" + f"{where_sql}{order_sql}{limit_sql}" + ) + rows = self._execute(query, tuple(params), fetch=True) or [] + return [self._row_to_post(row) for row in rows] + + def export(self, exporter, *, sentiment_filter: Sequence[str] | None = None, filename: str = "posts") -> None: + posts = self.fetch_posts(sentiment_filter=sentiment_filter) + exporter.export(posts, filename=filename) + + def close(self) -> None: + """Close all pooled connections.""" + + if self.driver == "postgres": # pragma: no cover - requires postgres runtime + self._pool.closeall() + else: + self._pool.closeall() + + # Internal helpers ------------------------------------------------- + + def _execute( + self, + query: str, + params: Sequence[Any] | None = None, + *, + fetch: bool = False, + ) -> List[dict[str, Any]] | None: + params = params or [] + with self.get_connection() as connection: + cursor = connection.cursor() + try: + cursor.execute(query, params) + if fetch: + columns = [desc[0] for desc in cursor.description] # type: ignore[assignment] + rows = [dict(zip(columns, row)) for row in cursor.fetchall()] + return rows + finally: + cursor.close() + return None + + def _row_to_post(self, row: dict[str, Any]) -> SocialPost: + created_at_raw = row["created_at"] + if isinstance(created_at_raw, datetime): + created_at = created_at_raw + else: + created_at = datetime.fromisoformat(str(created_at_raw)) + + metadata_value = row.get("metadata") + if isinstance(metadata_value, str): + try: + metadata = json.loads(metadata_value) + except json.JSONDecodeError: + metadata = {"raw": metadata_value} + else: + metadata = metadata_value or {} + + sentiment = None + if row.get("sentiment_label"): + sentiment = SentimentResult( + score=float(row.get("sentiment_score") or 0.0), + label=SentimentLabel(row["sentiment_label"]), + ) + + return SocialPost( + external_id=row["external_id"], + text=row["text"], + author=row["author"], + source=SocialSource(row["source"]), + created_at=created_at, + url=row.get("url"), + metadata=metadata, + sentiment=sentiment, + ) diff --git a/sentiment/retry.py b/sentiment/retry.py new file mode 100644 index 0000000..4b65380 --- /dev/null +++ b/sentiment/retry.py @@ -0,0 +1,42 @@ +"""Async retry helpers.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Tuple, TypeVar + +T = TypeVar("T") + + +async def execute_with_retry( + operation: Callable[[], Awaitable[T]], + *, + max_attempts: int, + backoff_factor: float, + retry_exceptions: Tuple[type[BaseException], ...], + logger: logging.Logger, + base_delay: float = 1.0, +) -> T: + """Execute an async callable with exponential backoff.""" + + attempt = 0 + while True: + try: + return await operation() + except retry_exceptions as exc: # type: ignore[arg-type] + attempt += 1 + if attempt >= max_attempts: + logger.error("Operation failed after %s attempts", attempt, exc_info=exc) + raise + + delay = base_delay * (backoff_factor ** (attempt - 1)) + logger.warning( + "Operation failed with %s. 
Retrying in %.2fs (attempt %s/%s)", + exc, + delay, + attempt + 1, + max_attempts, + ) + await asyncio.sleep(delay) diff --git a/sentiment/sanitizer.py b/sentiment/sanitizer.py new file mode 100644 index 0000000..1fc8443 --- /dev/null +++ b/sentiment/sanitizer.py @@ -0,0 +1,59 @@ +"""Utilities for validating and sanitizing incoming data.""" + +from __future__ import annotations + +import re +from datetime import datetime +from typing import Iterable + +from zoneinfo import ZoneInfo + +from .exceptions import ValidationError +from .models import SocialPost + +_TEXT_FILTER = re.compile(r"[^\w\s.,!?/:;-]") +_WHITESPACE_RE = re.compile(r"\s+") + + +def sanitize_text(text: str) -> str: + """Normalize whitespace and strip unwanted characters.""" + + normalized = _WHITESPACE_RE.sub(" ", (text or "").strip()) + return _TEXT_FILTER.sub("", normalized) + + +def ensure_timezone(timestamp: datetime, timezone: str) -> datetime: + """Return the timestamp converted to the configured timezone.""" + + zone = ZoneInfo(timezone) + if timestamp.tzinfo is None: + return timestamp.replace(tzinfo=zone) + return timestamp.astimezone(zone) + + +def validate_post(post: SocialPost, timezone: str) -> SocialPost: + """Validate mandatory fields and normalize timestamps.""" + + if not post.external_id: + raise ValidationError("Post is missing an external identifier") + + sanitized_text = sanitize_text(post.text) + if not sanitized_text: + raise ValidationError("Post does not contain analyzable text") + + post.text = sanitized_text + post.created_at = ensure_timezone(post.created_at, timezone) + return post + + +def chunk_items(items: Iterable[SocialPost], chunk_size: int) -> Iterable[list[SocialPost]]: + """Yield posts in batches for easier processing.""" + + chunk: list[SocialPost] = [] + for item in items: + chunk.append(item) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if chunk: + yield chunk diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py new file mode 100644 index 0000000..e2db719 --- /dev/null +++ b/tests/test_analyzer.py @@ -0,0 +1,39 @@ +"""Tests for the sentiment analyzer service.""" + +from __future__ import annotations + +import unittest + +from sentiment.analyzer import SentimentAnalyzer +from sentiment.config import AppConfig +from sentiment.models import SentimentLabel + + +class SentimentAnalyzerTests(unittest.TestCase): + """Validate sentiment scoring.""" + + def setUp(self) -> None: + self.config = AppConfig( + database_url="sqlite:///tmp/test.db", + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + ) + self.analyzer = SentimentAnalyzer(self.config) + + def test_positive_sentiment(self) -> None: + result = self.analyzer.analyze("This is a great and awesome day") + self.assertEqual(result.label, SentimentLabel.POSITIVE) + + def test_negative_sentiment(self) -> None: + result = self.analyzer.analyze("This is a bad and terrible idea") + self.assertEqual(result.label, SentimentLabel.NEGATIVE) + + def test_neutral_sentiment(self) -> None: + result = self.analyzer.analyze("The sky is blue") + self.assertEqual(result.label, SentimentLabel.NEUTRAL) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..6532590 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +"""Tests for configuration management.""" + +from __future__ 
import annotations + +import os +import unittest +from unittest import mock + +from sentiment.config import AppConfig +from sentiment.exceptions import ConfigurationError + + +class AppConfigTests(unittest.TestCase): + """Validate AppConfig loading and validation.""" + + def test_from_env_missing_database_url(self) -> None: + with mock.patch.dict(os.environ, { + "TWITTER_BEARER_TOKEN": "token", + "REDDIT_CLIENT_ID": "client", + "REDDIT_CLIENT_SECRET": "secret", + }, clear=True): + with self.assertRaises(ConfigurationError): + AppConfig.from_env() + + def test_from_env_success(self) -> None: + with mock.patch.dict(os.environ, { + "DATABASE_URL": "sqlite:///tmp/test.db", + "TWITTER_BEARER_TOKEN": "token", + "REDDIT_CLIENT_ID": "client", + "REDDIT_CLIENT_SECRET": "secret", + "SENTIMENT_SOURCES": "twitter,reddit", + "REQUEST_TIMEOUT": "5", + "BATCH_SIZE": "10", + }, clear=True): + config = AppConfig.from_env() + self.assertEqual(config.database_url, "sqlite:///tmp/test.db") + self.assertEqual(config.batch_size, 10) + self.assertIn("twitter", config.enabled_sources) + self.assertIn("reddit", config.enabled_sources) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..1537638 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,89 @@ +"""Pipeline integration tests.""" + +from __future__ import annotations + +import tempfile +import unittest +from datetime import datetime, timezone +from pathlib import Path + +from sentiment.analyzer import SentimentAnalyzer +from sentiment.config import AppConfig +from sentiment.exporter import DataExporter +from sentiment.metrics import MetricsTracker +from sentiment.models import SocialPost, SocialSource +from sentiment.pipeline import SentimentPipeline +from sentiment.repository import DataRepository +from sentiment.fetchers.base import BaseFetcher + + +class DummyFetcher(BaseFetcher): + """Test double returning pre-defined posts.""" + + def __init__(self, config: AppConfig, posts: list[SocialPost]): + super().__init__(config) + self._posts = posts + + async def fetch(self, query: str, limit: int): + return self._posts[:limit] + + +class PipelineTests(unittest.IsolatedAsyncioTestCase): + """Validate the pipeline end-to-end using sqlite.""" + + async def asyncSetUp(self) -> None: + self.tmp_dir = tempfile.TemporaryDirectory() + self.export_dir = Path(self.tmp_dir.name) / "exports" + db_path = f"sqlite:///{self.tmp_dir.name}/pipeline.sqlite3" + self.config = AppConfig( + database_url=db_path, + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + batch_size=2, + ) + self.repository = DataRepository(self.config) + self.analyzer = SentimentAnalyzer(self.config) + self.exporter = DataExporter(self.export_dir) + self.metrics = MetricsTracker() + now = datetime.now(tz=timezone.utc) + posts = [ + SocialPost( + external_id="1", + text="I love async pipelines", + author="alice", + source=SocialSource.TWITTER, + created_at=now, + ), + SocialPost( + external_id="2", + text="I hate slow databases", + author="bob", + source=SocialSource.REDDIT, + created_at=now, + ), + ] + self.fetcher = DummyFetcher(self.config, posts) + self.pipeline = SentimentPipeline( + config=self.config, + repository=self.repository, + analyzer=self.analyzer, + fetchers=[self.fetcher], + exporter=self.exporter, + metrics=self.metrics, + ) + + async def asyncTearDown(self) -> None: + self.repository.close() + self.tmp_dir.cleanup() + 
+ async def test_pipeline_runs_and_persists_posts(self) -> None: + result = await self.pipeline.run(query="ai", limit=10, export=False) + self.assertGreaterEqual(result.metrics["posts_processed"], 1) + posts = self.repository.fetch_posts() + self.assertEqual(len(posts), 2) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_repository.py b/tests/test_repository.py new file mode 100644 index 0000000..3d94aba --- /dev/null +++ b/tests/test_repository.py @@ -0,0 +1,63 @@ +"""Tests for the data repository.""" + +from __future__ import annotations + +import tempfile +import unittest +from datetime import datetime, timezone + +from sentiment.config import AppConfig +from sentiment.models import SentimentLabel, SentimentResult, SocialPost, SocialSource +from sentiment.repository import DataRepository + + +class DataRepositoryTests(unittest.TestCase): + """Validate repository behaviour using sqlite.""" + + def setUp(self) -> None: + self._tmp_dir = tempfile.TemporaryDirectory() + db_path = f"sqlite:///{self._tmp_dir.name}/test.sqlite3" + self.config = AppConfig( + database_url=db_path, + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + ) + self.repository = DataRepository(self.config) + + def tearDown(self) -> None: + self.repository.close() + self._tmp_dir.cleanup() + + def _build_post(self, external_id: str, sentiment_label: SentimentLabel) -> SocialPost: + post = SocialPost( + external_id=external_id, + text="I love testing" if sentiment_label == SentimentLabel.POSITIVE else "I hate bugs", + author="tester", + source=SocialSource.TWITTER, + created_at=datetime.now(tz=timezone.utc), + ) + post.sentiment = SentimentResult(score=0.9 if sentiment_label == SentimentLabel.POSITIVE else -0.9, label=sentiment_label) + return post + + def test_insert_and_fetch_posts(self) -> None: + post = self._build_post("abc", SentimentLabel.POSITIVE) + self.assertTrue(self.repository.insert_post(post)) + self.assertFalse(self.repository.insert_post(post)) # duplicate + posts = self.repository.fetch_posts() + self.assertEqual(len(posts), 1) + self.assertEqual(posts[0].external_id, "abc") + + def test_fetch_with_sentiment_filter(self) -> None: + positive = self._build_post("p", SentimentLabel.POSITIVE) + negative = self._build_post("n", SentimentLabel.NEGATIVE) + self.repository.insert_post(positive) + self.repository.insert_post(negative) + filtered = self.repository.fetch_posts(sentiment_filter=[SentimentLabel.POSITIVE.value]) + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0].external_id, "p") + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py new file mode 100644 index 0000000..d53a7d0 --- /dev/null +++ b/tests/test_sanitizer.py @@ -0,0 +1,33 @@ +"""Tests for sanitizer utilities.""" + +from __future__ import annotations + +import unittest +from datetime import datetime + +from sentiment.models import SocialPost, SocialSource +from sentiment.sanitizer import sanitize_text, validate_post + + +class SanitizerTests(unittest.TestCase): + """Ensure sanitizer works as expected.""" + + def test_sanitize_text(self) -> None: + text = " Hello\tWorld!!@@## " + self.assertEqual(sanitize_text(text), "Hello World!!") + + def test_validate_post_normalizes_timezone(self) -> None: + post = SocialPost( + external_id="1", + text="Some text!!!", + author="alice", + source=SocialSource.TWITTER, + created_at=datetime(2024, 1, 1, 12, 0, 0), + 
) + validated = validate_post(post, "UTC") + self.assertIsNotNone(validated.created_at.tzinfo) + self.assertEqual(validated.text, "Some text!!!") + + +if __name__ == "__main__": # pragma: no cover + unittest.main()