diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a4f0157
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,20 @@
+__pycache__/
+*.py[cod]
+*.log
+*.sqlite3
+*.db
+.env
+.env.*
+.venv/
+.idea/
+.vscode/
+.mypy_cache/
+.pytest_cache/
+coverage.xml
+htmlcov/
+exports/
+metrics.json
+dist/
+build/
+*.egg-info/
+db_utils/__pycache__/
diff --git a/README.md b/README.md
index 7b60ffb..eb918cb 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,86 @@
-# Citadel
\ No newline at end of file
+# Sentiment Analysis Pipeline
+
+## Overview
+
+This repository now ships a production-ready sentiment analysis pipeline that ingests
+posts from social sources (Twitter and Reddit), scores their sentiment, persists the
+results, exports structured data, and records runtime metrics. The pipeline was
+refactored to emphasise reliability, observability, and maintainability.
+
+## Key Features
+
+- **Centralised configuration** via `AppConfig` with strict validation of required
+  environment variables, configurable timeouts, retry counts, and batching controls.
+- **Structured logging** across the stack with DEBUG/INFO/WARNING/ERROR semantics.
+- **Robust API access** featuring rate limiting, exponential backoff retries,
+  response validation, and sanitisation of fetched payloads.
+- **Async orchestration** that fetches from multiple sources concurrently while
+  respecting rate limits and tracking batch progress.
+- **Modular architecture** with dedicated components for fetchers, analysis,
+  persistence, exporting, and metrics.
+- **Database reliability** through a repository that uses connection pooling,
+  schema migrations, context-managed operations, and duplicate detection.
+- **Data hygiene** by sanitising text, validating timestamps, filtering duplicates,
+  and converting everything to a consistent timezone.
+- **Operational insights** via metrics for processed posts, duplicates, failures,
+  and per-sentiment distribution alongside optional JSON/CSV exports.
+
+## Environment Variables
+
+| Variable | Description |
+| --- | --- |
+| `DATABASE_URL` | Database connection string. Supports `sqlite:///path/to.db` or `postgres://user:pass@host:port/db`. |
+| `TWITTER_BEARER_TOKEN` | Required when enabling the Twitter fetcher. |
+| `REDDIT_CLIENT_ID` / `REDDIT_CLIENT_SECRET` | Required when enabling the Reddit fetcher. |
+| `SENTIMENT_SOURCES` | Comma-separated list of fetchers to enable (`twitter,reddit`). Defaults to both. |
+| `REQUEST_TIMEOUT` | HTTP request timeout in seconds (default `10`). |
+| `MAX_RETRIES` | Number of API retry attempts (default `3`). |
+| `RETRY_BACKOFF_FACTOR` | Exponential backoff multiplier (default `2`). |
+| `BATCH_SIZE` | Number of posts processed concurrently in each batch (default `25`). |
+| `RATE_LIMIT_PER_MINUTE` | Maximum API calls per minute (default `60`). |
+| `EXPORT_PATH` | Directory for exported datasets (default `exports`). |
+| `METRICS_EXPORT_PATH` | Output path for metrics JSON (default `metrics.json`). |
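+
+For a quick local run against SQLite, a minimal setup can look like this (placeholder
+values; only the secrets for the sources you enable are needed, and `AppConfig.from_env()`
+fails fast with a `ConfigurationError` when anything required is missing):
+
+```bash
+export DATABASE_URL="sqlite:///tmp/sentiment.db"
+export SENTIMENT_SOURCES="twitter"
+export TWITTER_BEARER_TOKEN="<your-bearer-token>"
+```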
+
+## Running the Pipeline
+
+1. Ensure Python 3.11+ is available and install dependencies (e.g. `pip install -r requirements.txt`).
+2. Export the required environment variables shown above.
+3. Launch the CLI:
+
+```bash
+python -m sentiment.cli --query "ai" --limit 50 --export --log-level INFO
+```
+
+The CLI validates configuration, spins up the async pipeline, streams logs, and writes
+metrics/exports on completion.
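+
+The CLI is a thin wrapper around the library, so the same components can be wired
+together directly from Python. A minimal sketch (it assumes the environment variables
+above are exported and, for brevity, only enables the Twitter fetcher):
+
+```python
+import asyncio
+
+from sentiment.analyzer import SentimentAnalyzer
+from sentiment.config import AppConfig
+from sentiment.exporter import DataExporter
+from sentiment.fetchers.twitter import TwitterFetcher
+from sentiment.metrics import MetricsTracker
+from sentiment.pipeline import SentimentPipeline
+from sentiment.repository import DataRepository
+
+
+async def main() -> None:
+    config = AppConfig.from_env()        # validates DATABASE_URL, tokens, etc.
+    repository = DataRepository(config)  # opens the pool and applies migrations
+    pipeline = SentimentPipeline(
+        config=config,
+        repository=repository,
+        analyzer=SentimentAnalyzer(config),
+        fetchers=[TwitterFetcher(config)],
+        exporter=DataExporter(config.export_path),
+        metrics=MetricsTracker(),
+    )
+    try:
+        result = await pipeline.run(query="ai", limit=25, export=True)
+        print(result.metrics)
+    finally:
+        repository.close()
+
+
+asyncio.run(main())
+```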
+
+## Testing
+
+Unit tests cover configuration, sanitisation, the sentiment analyser, repository,
+and pipeline orchestration. Execute them with:
+
+```bash
+python -m unittest discover tests
+```
+
+## Project Structure
+
+```
+sentiment/
+  analyzer.py    # Sentiment scoring service
+  cli.py         # Command-line entry point
+  config.py      # Dataclass-backed configuration
+  exporter.py    # JSON/CSV export helpers
+  fetchers/      # Twitter & Reddit API clients derived from BaseFetcher
+  metrics.py     # Metrics tracking utilities
+  pipeline.py    # Async orchestrator
+  repository.py  # DB access layer with pooling + migrations
+  sanitizer.py   # Text/time validation & sanitisation helpers
+```
+
+## Notes
+
+- Postgres deployments require `psycopg2`. SQLite is supported out-of-the-box for
+  local testing and is used inside the automated unit tests.
+- Metrics are exported after each pipeline run and can be ingested by dashboards or
+  alerting tooling as needed.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..83ec125
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+psycopg2-binary>=2.9
diff --git a/sentiment/__init__.py b/sentiment/__init__.py
new file mode 100644
index 0000000..338beec
--- /dev/null
+++ b/sentiment/__init__.py
@@ -0,0 +1,6 @@
+"""High level sentiment analysis toolkit."""
+
+from .config import AppConfig  # noqa: F401
+from .analyzer import SentimentAnalyzer  # noqa: F401
+from .repository import DataRepository  # noqa: F401
+from .pipeline import SentimentPipeline  # noqa: F401
diff --git a/sentiment/analyzer.py b/sentiment/analyzer.py
new file mode 100644
index 0000000..22956ad
--- /dev/null
+++ b/sentiment/analyzer.py
@@ -0,0 +1,57 @@
+"""Lightweight sentiment analyzer service."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Set
+
+from .config import AppConfig
+from .models import SentimentLabel, SentimentResult
+
+
+@dataclass
+class SentimentAnalyzer:
+    """Performs rule-based sentiment analysis on sanitized text."""
+
+    config: AppConfig
+    positive_words: Set[str] = field(default_factory=lambda: {
+        "great",
+        "awesome",
+        "good",
+        "love",
+        "happy",
+        "fantastic",
+        "win",
+        "excellent",
+    })
+    negative_words: Set[str] = field(default_factory=lambda: {
+        "bad",
+        "terrible",
+        "hate",
+        "sad",
+        "awful",
+        "fail",
+        "poor",
+        "angry",
+    })
+
+    def analyze(self, text: str) -> SentimentResult:
+        """Return the sentiment label for the given text."""
+
+        tokens = [token.lower() for token in text.split() if token]
+        score = 0
+        for token in tokens:
+            if token in self.positive_words:
+                score += 1
+            elif token in self.negative_words:
+                score -= 1
+
+        normalized_score = score / max(len(tokens), 1)
+        if normalized_score >= self.config.sentiment_positive_threshold:
+            label = SentimentLabel.POSITIVE
+        elif normalized_score <= self.config.sentiment_negative_threshold:
+            label = SentimentLabel.NEGATIVE
+        else:
+            label = SentimentLabel.NEUTRAL
+
+        return SentimentResult(score=normalized_score, label=label)
diff --git a/sentiment/cli.py b/sentiment/cli.py
new file mode 100644
index 0000000..7811157
--- /dev/null
+++ b/sentiment/cli.py
@@ -0,0 +1,122 @@
+"""Command line interface for the sentiment pipeline."""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import logging
+import sys
+from typing import List
+
+from .analyzer import SentimentAnalyzer
+from .config import AppConfig
+from .exceptions import ConfigurationError
+from .exporter import DataExporter
+from .fetchers.reddit import RedditFetcher
+from .fetchers.twitter import TwitterFetcher
+from .logging_utils import configure_logging
+from .metrics import MetricsTracker
+from .pipeline import SentimentPipeline
+from .repository import DataRepository
+
+
+async def async_main(args: argparse.Namespace) -> int:
+    """Run the async pipeline entry point."""
+
+    configure_logging(args.log_level)
+    logger = logging.getLogger("sentiment.cli")
+
+    try:
+        config = AppConfig.from_env()
+    except ConfigurationError as exc:
+        logger.error("Configuration error: %s", exc)
+        return 1
+
+    fetchers: List = []
+    if "twitter" in config.enabled_sources:
+        try:
+            fetchers.append(TwitterFetcher(config))
+        except Exception as exc:
+            logger.error("Twitter fetcher initialization failed: %s", exc)
+    if "reddit" in config.enabled_sources:
+        try:
+            fetchers.append(RedditFetcher(config))
+        except Exception as exc:
+            logger.error("Reddit fetcher initialization failed: %s", exc)
+
+    if not fetchers:
+        logger.error("No fetchers available. Ensure credentials are configured correctly.")
+        return 1
+
+    repository = DataRepository(config)
+    analyzer = SentimentAnalyzer(config)
+    exporter = DataExporter(config.export_path)
+    metrics = MetricsTracker()
+
+    pipeline = SentimentPipeline(
+        config=config,
+        repository=repository,
+        analyzer=analyzer,
+        fetchers=fetchers,
+        exporter=exporter,
+        metrics=metrics,
+    )
+
+    try:
+        result = await pipeline.run(
+            query=args.query,
+            limit=args.limit,
+            sentiment_filter=args.sentiment,
+            sort_by=args.sort_by,
+            descending=args.descending,
+            export=args.export,
+        )
+        logger.info(
+            "Processed %s posts (duplicates: %s)",
+            result.metrics["posts_processed"],
+            result.metrics["duplicates_skipped"],
+        )
+        return 0
+    finally:
+        repository.close()
+
+
+def build_parser() -> argparse.ArgumentParser:
+    """Create an argument parser for the CLI."""
+
+    parser = argparse.ArgumentParser(description="Social sentiment pipeline")
+    parser.add_argument("--query", default="ai", help="Search query")
+    parser.add_argument("--limit", type=int, default=25, help="Max items per fetcher")
+    parser.add_argument(
+        "--sentiment",
+        nargs="*",
+        choices=["positive", "negative", "neutral"],
+        help="Filter output posts by sentiment",
+    )
+    parser.add_argument(
+        "--sort-by",
+        choices=["created_at", "author", "sentiment_score"],
+        default="created_at",
+        help="Sort column",
+    )
+    parser.add_argument("--descending", action="store_true", help="Sort descending")
+    parser.add_argument("--export", action="store_true", help="Export processed posts")
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+        help="Log verbosity",
+    )
+    return parser
+
+
+def main(argv: List[str] | None = None) -> int:
+    """Synchronous CLI entry point."""
+
+    parser = build_parser()
+    args = parser.parse_args(argv)
+    return asyncio.run(async_main(args))
+
+
+if __name__ == "__main__":  # pragma: no cover - manual execution
+    sys.exit(main())
diff --git a/sentiment/config.py b/sentiment/config.py
new file mode 100644
index 0000000..c33caf9
--- /dev/null
+++ b/sentiment/config.py
@@ -0,0 +1,129 @@
+"""Application configuration management."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+from .exceptions import ConfigurationError
+
+
+@dataclass(slots=True)
+class AppConfig:
+    """Runtime configuration for the sentiment pipeline."""
+
+    database_url: str
+    twitter_bearer_token: str | None = None
+    reddit_client_id: str | None = None
+    reddit_client_secret: str | None = None
+    request_timeout: float = 10.0
+    fetch_timeout: float = 15.0
+
max_retries: int = 3 + retry_backoff_factor: float = 2.0 + rate_limit_per_minute: int = 60 + rate_limit_period: float = 60.0 + batch_size: int = 25 + concurrency: int = 5 + export_path: str = "exports" + export_format: str = "json" + metrics_export_path: str = "metrics.json" + min_pool_size: int = 1 + max_pool_size: int = 5 + schema_version: int = 1 + timezone: str = "UTC" + duplicate_window_minutes: int = 120 + sentiment_positive_threshold: float = 0.2 + sentiment_negative_threshold: float = -0.2 + sources: List[str] = field(default_factory=lambda: ["twitter", "reddit"]) + + def __post_init__(self) -> None: + """Normalize configured sources.""" + + self.sources = [source.strip().lower() for source in self.sources if source.strip()] + + @classmethod + def from_env(cls) -> "AppConfig": + """Build the configuration from environment variables.""" + + env = os.environ + sources = env.get("SENTIMENT_SOURCES") + parsed_sources = ( + [src.strip().lower() for src in sources.split(",") if src.strip()] + if sources + else ["twitter", "reddit"] + ) + + config = cls( + database_url=env.get("DATABASE_URL", "").strip(), + twitter_bearer_token=env.get("TWITTER_BEARER_TOKEN"), + reddit_client_id=env.get("REDDIT_CLIENT_ID"), + reddit_client_secret=env.get("REDDIT_CLIENT_SECRET"), + request_timeout=float(env.get("REQUEST_TIMEOUT", 10.0)), + fetch_timeout=float(env.get("FETCH_TIMEOUT", 15.0)), + max_retries=int(env.get("MAX_RETRIES", 3)), + retry_backoff_factor=float(env.get("RETRY_BACKOFF_FACTOR", 2.0)), + rate_limit_per_minute=int(env.get("RATE_LIMIT_PER_MINUTE", 60)), + rate_limit_period=float(env.get("RATE_LIMIT_PERIOD", 60.0)), + batch_size=int(env.get("BATCH_SIZE", 25)), + concurrency=int(env.get("CONCURRENCY", 5)), + export_path=env.get("EXPORT_PATH", "exports"), + export_format=env.get("EXPORT_FORMAT", "json"), + metrics_export_path=env.get("METRICS_EXPORT_PATH", "metrics.json"), + min_pool_size=int(env.get("DB_POOL_MIN", 1)), + max_pool_size=int(env.get("DB_POOL_MAX", 5)), + schema_version=int(env.get("SCHEMA_VERSION", 1)), + timezone=env.get("TIMEZONE", "UTC"), + duplicate_window_minutes=int(env.get("DUPLICATE_WINDOW_MINUTES", 120)), + sentiment_positive_threshold=float(env.get("SENTIMENT_POS_THRESHOLD", 0.2)), + sentiment_negative_threshold=float(env.get("SENTIMENT_NEG_THRESHOLD", -0.2)), + sources=parsed_sources, + ) + config.validate() + return config + + @property + def enabled_sources(self) -> List[str]: + """Return sources that can be instantiated with the available secrets.""" + + enabled: list[str] = [] + for source in self.sources: + if source == "twitter" and self.twitter_bearer_token: + enabled.append(source) + elif ( + source == "reddit" + and self.reddit_client_id + and self.reddit_client_secret + ): + enabled.append(source) + return enabled + + def validate(self) -> None: + """Validate that the configuration satisfies all prerequisites.""" + + missing: list[str] = [] + if not self.database_url: + missing.append("DATABASE_URL") + + normalized_sources = {source.strip().lower() for source in self.sources} + if "twitter" in normalized_sources and not self.twitter_bearer_token: + missing.append("TWITTER_BEARER_TOKEN") + if "reddit" in normalized_sources: + if not self.reddit_client_id: + missing.append("REDDIT_CLIENT_ID") + if not self.reddit_client_secret: + missing.append("REDDIT_CLIENT_SECRET") + + if missing: + raise ConfigurationError( + "Missing required environment variables: " + ", ".join(sorted(set(missing))) + ) + + if self.max_pool_size < self.min_pool_size: + raise 
ConfigurationError("DB_POOL_MAX must be >= DB_POOL_MIN") + + if self.batch_size <= 0: + raise ConfigurationError("BATCH_SIZE must be a positive integer") + + if not normalized_sources: + raise ConfigurationError("At least one data source must be specified") diff --git a/sentiment/exceptions.py b/sentiment/exceptions.py new file mode 100644 index 0000000..5b1b46e --- /dev/null +++ b/sentiment/exceptions.py @@ -0,0 +1,31 @@ +"""Custom exceptions for the sentiment analysis pipeline.""" + +from __future__ import annotations + + +class SentimentError(Exception): + """Base exception for sentiment related errors.""" + + +class ConfigurationError(SentimentError): + """Raised when the application configuration is invalid.""" + + +class APIFetchError(SentimentError): + """Raised when an upstream API call fails.""" + + +class DatabaseError(SentimentError): + """Raised when database operations fail.""" + + +class DuplicatePostError(DatabaseError): + """Raised when attempting to insert an already persisted post.""" + + +class RateLimitError(SentimentError): + """Raised when the rate limiter cannot provide capacity in time.""" + + +class ValidationError(SentimentError): + """Raised when data validation fails.""" diff --git a/sentiment/exporter.py b/sentiment/exporter.py new file mode 100644 index 0000000..3dbac67 --- /dev/null +++ b/sentiment/exporter.py @@ -0,0 +1,63 @@ +"""Data export helpers.""" + +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Iterable, List + +from .models import SocialPost + + +class DataExporter: + """Exports processed posts to disk.""" + + def __init__(self, output_dir: str | Path) -> None: + self._output_dir = Path(output_dir) + self._output_dir.mkdir(parents=True, exist_ok=True) + + def export(self, posts: Iterable[SocialPost], *, filename: str, fmt: str = "json") -> Path: + """Export posts to the requested format.""" + + if fmt not in {"json", "csv"}: + raise ValueError("Unsupported export format") + + path = self._output_dir / f"{filename}.{fmt}" + if fmt == "json": + self._export_json(posts, path) + else: + self._export_csv(posts, path) + return path + + def _export_json(self, posts: Iterable[SocialPost], path: Path) -> None: + serialized = [self._serialize_post(post) for post in posts] + with path.open("w", encoding="utf-8") as handle: + json.dump(serialized, handle, indent=2) + + def _export_csv(self, posts: Iterable[SocialPost], path: Path) -> None: + serialized = [self._serialize_post(post) for post in posts] + if not serialized: + path.touch() + return + + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(serialized[0].keys())) + writer.writeheader() + writer.writerows(serialized) + + @staticmethod + def _serialize_post(post: SocialPost) -> dict: + result = { + "external_id": post.external_id, + "author": post.author, + "source": post.source.value, + "created_at": post.created_at.isoformat(), + "text": post.text, + "url": post.url, + "metadata": post.metadata, + } + if post.sentiment: + result["sentiment_score"] = post.sentiment.score + result["sentiment_label"] = post.sentiment.label.value + return result diff --git a/sentiment/fetchers/__init__.py b/sentiment/fetchers/__init__.py new file mode 100644 index 0000000..0392ce0 --- /dev/null +++ b/sentiment/fetchers/__init__.py @@ -0,0 +1 @@ +"""Fetcher implementations.""" diff --git a/sentiment/fetchers/base.py b/sentiment/fetchers/base.py new file mode 100644 index 0000000..cffbe0c --- /dev/null +++ 
b/sentiment/fetchers/base.py @@ -0,0 +1,90 @@ +"""Base fetcher definition.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import urllib.error +import urllib.parse +import urllib.request +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Mapping, Optional + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost +from ..rate_limiter import AsyncRateLimiter +from ..retry import execute_with_retry + + +class BaseFetcher(ABC): + """Common functionality for remote data fetchers.""" + + def __init__(self, config: AppConfig, logger: logging.Logger | None = None) -> None: + self.config = config + self.logger = logger or logging.getLogger(self.__class__.__name__) + self.rate_limiter = AsyncRateLimiter( + config.rate_limit_per_minute, + config.rate_limit_period, + ) + + @abstractmethod + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + """Fetch posts for the given query.""" + + async def _request(self, url: str, params: Mapping[str, Any], headers: Mapping[str, str]) -> Dict[str, Any]: + """Execute an HTTP GET request with retry and rate limiting.""" + + await self.rate_limiter.acquire() + + async def _call() -> Dict[str, Any]: + return await asyncio.to_thread(self._perform_request, url, params, headers) + + return await execute_with_retry( + _call, + max_attempts=self.config.max_retries, + backoff_factor=self.config.retry_backoff_factor, + retry_exceptions=(APIFetchError, urllib.error.URLError, urllib.error.HTTPError, ConnectionError), + logger=self.logger, + base_delay=1.0, + ) + + def _perform_request( + self, + url: str, + params: Mapping[str, Any], + headers: Mapping[str, str], + ) -> Dict[str, Any]: + """Perform an HTTP GET request synchronously.""" + + try: + query = urllib.parse.urlencode({k: v for k, v in params.items() if v is not None}) + delimiter = "?" if "?" 
not in url else "&" + target = f"{url}{delimiter}{query}" if query else url + request = urllib.request.Request(target, headers=headers) + with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: + payload = response.read().decode("utf-8") + except urllib.error.HTTPError as exc: # pragma: no cover - network errors + raise APIFetchError(f"HTTP error {exc.code}: {exc.reason}") from exc + except urllib.error.URLError as exc: # pragma: no cover - network errors + raise APIFetchError(f"Network error: {exc.reason}") from exc + + try: + return json.loads(payload) + except json.JSONDecodeError as exc: + raise APIFetchError("API response was not valid JSON") from exc + + @staticmethod + def _parse_timestamp(value: str | None) -> Optional[datetime]: + if not value: + return None + normalized = value.replace("Z", "+00:00") + return datetime.fromisoformat(normalized) + + @staticmethod + def _validate_keys(payload: Mapping[str, Any], required_keys: List[str]) -> None: + missing = [key for key in required_keys if key not in payload] + if missing: + raise APIFetchError(f"API response missing keys: {', '.join(missing)}") diff --git a/sentiment/fetchers/reddit.py b/sentiment/fetchers/reddit.py new file mode 100644 index 0000000..5539c61 --- /dev/null +++ b/sentiment/fetchers/reddit.py @@ -0,0 +1,58 @@ +"""Reddit fetcher implementation.""" + +from __future__ import annotations + +import base64 +from datetime import datetime, timezone +from typing import Any, Dict, List + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost, SocialSource +from ..sanitizer import sanitize_text +from .base import BaseFetcher + + +class RedditFetcher(BaseFetcher): + """Fetch posts from Reddit's public search API.""" + + API_URL = "https://www.reddit.com/search.json" + + def __init__(self, config: AppConfig) -> None: + super().__init__(config) + if not config.reddit_client_id or not config.reddit_client_secret: + raise APIFetchError("Reddit credentials are missing") + + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + params = { + "q": query, + "limit": min(max(limit, 10), 100), + "sort": "new", + "restrict_sr": False, + } + headers = { + "User-Agent": "sentiment-analyzer/1.0", + "Authorization": f"Basic {self._encode_credentials()}", + } + response = await self._request(self.API_URL, params, headers) + self._validate_keys(response, ["data"]) + children = response["data"].get("children", []) + return [self._build_post(child.get("data", {})) for child in children] + + def _encode_credentials(self) -> str: + token = f"{self.config.reddit_client_id}:{self.config.reddit_client_secret}" + encoded = base64.b64encode(token.encode("utf-8")).decode("ascii") + return encoded + + def _build_post(self, payload: Dict[str, Any]) -> SocialPost: + created_utc = payload.get("created_utc") + created_at = datetime.fromtimestamp(created_utc, tz=timezone.utc) if created_utc else datetime.now(tz=timezone.utc) + return SocialPost( + external_id=str(payload.get("id")), + text=sanitize_text(payload.get("selftext") or payload.get("title", "")), + author=str(payload.get("author", "unknown")), + source=SocialSource.REDDIT, + created_at=created_at, + url=f"https://www.reddit.com{payload.get('permalink', '')}", + metadata={"raw": payload}, + ) diff --git a/sentiment/fetchers/twitter.py b/sentiment/fetchers/twitter.py new file mode 100644 index 0000000..ca1c443 --- /dev/null +++ b/sentiment/fetchers/twitter.py @@ -0,0 +1,49 @@ +"""Twitter API fetcher.""" + 
+from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List + +from ..config import AppConfig +from ..exceptions import APIFetchError +from ..models import SocialPost, SocialSource +from ..sanitizer import sanitize_text +from .base import BaseFetcher + + +class TwitterFetcher(BaseFetcher): + """Fetches tweets using the Twitter v2 API.""" + + API_URL = "https://api.twitter.com/2/tweets/search/recent" + + def __init__(self, config: AppConfig) -> None: + super().__init__(config) + if not config.twitter_bearer_token: + raise APIFetchError("Twitter bearer token is missing") + + async def fetch(self, query: str, limit: int) -> List[SocialPost]: + params = { + "query": query, + "max_results": min(max(limit, 10), 100), + "tweet.fields": "created_at,author_id,lang" + } + headers = { + "Authorization": f"Bearer {self.config.twitter_bearer_token}", + } + response = await self._request(self.API_URL, params, headers) + self._validate_keys(response, ["data"]) + return [self._build_post(tweet) for tweet in response.get("data", [])] + + def _build_post(self, tweet: Dict[str, Any]) -> SocialPost: + text = sanitize_text(tweet.get("text", "")) + created_at = self._parse_timestamp(tweet.get("created_at")) or datetime.utcnow() + return SocialPost( + external_id=str(tweet.get("id")), + text=text, + author=str(tweet.get("author_id", "unknown")), + source=SocialSource.TWITTER, + created_at=created_at, + url=f"https://twitter.com/i/web/status/{tweet.get('id')}", + metadata={"raw": tweet}, + ) diff --git a/sentiment/logging_utils.py b/sentiment/logging_utils.py new file mode 100644 index 0000000..cad5ed6 --- /dev/null +++ b/sentiment/logging_utils.py @@ -0,0 +1,32 @@ +"""Logging helpers for the sentiment analysis project.""" + +from __future__ import annotations + +import logging +from typing import Optional + + +def configure_logging(level: str = "INFO") -> logging.Logger: + """Configure and return the application logger. + + Args: + level: Desired log level name. + + Returns: + Configured root logger. 
+ """ + + numeric_level = getattr(logging, level.upper(), logging.INFO) + logger = logging.getLogger("sentiment") + + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.setLevel(numeric_level) + logger.propagate = False + return logger diff --git a/sentiment/metrics.py b/sentiment/metrics.py new file mode 100644 index 0000000..6747ede --- /dev/null +++ b/sentiment/metrics.py @@ -0,0 +1,56 @@ +"""Metrics tracking for the sentiment pipeline.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict + + +@dataclass +class MetricsTracker: + """Track ingestion metrics.""" + + posts_processed: int = 0 + duplicates_skipped: int = 0 + failures: int = 0 + sentiment_distribution: Dict[str, int] = field(default_factory=lambda: { + "positive": 0, + "negative": 0, + "neutral": 0, + }) + + def record_sentiment(self, label: str) -> None: + """Update counters for sentiment labels.""" + + self.posts_processed += 1 + self.sentiment_distribution[label] = self.sentiment_distribution.get(label, 0) + 1 + + def record_duplicate(self) -> None: + """Mark that a duplicate post has been skipped.""" + + self.duplicates_skipped += 1 + + def record_failure(self) -> None: + """Increment failure counter.""" + + self.failures += 1 + + def export(self, path: str | Path) -> None: + """Persist metrics to disk.""" + + destination = Path(path) + destination.parent.mkdir(parents=True, exist_ok=True) + with destination.open("w", encoding="utf-8") as handle: + json.dump(self.as_dict(), handle, indent=2) + + def as_dict(self) -> Dict[str, int | Dict[str, int]]: + """Serialize metrics to a dictionary.""" + + return { + "posts_processed": self.posts_processed, + "duplicates_skipped": self.duplicates_skipped, + "failures": self.failures, + "sentiment_distribution": self.sentiment_distribution, + } diff --git a/sentiment/models.py b/sentiment/models.py new file mode 100644 index 0000000..b268441 --- /dev/null +++ b/sentiment/models.py @@ -0,0 +1,53 @@ +"""Data models used throughout the sentiment pipeline.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + + +class SocialSource(str, Enum): + """Supported social data sources.""" + + TWITTER = "twitter" + REDDIT = "reddit" + + +class SentimentLabel(str, Enum): + """Possible sentiment labels.""" + + POSITIVE = "positive" + NEGATIVE = "negative" + NEUTRAL = "neutral" + + +@dataclass(slots=True) +class SentimentResult: + """Represents the sentiment of a chunk of text.""" + + score: float + label: SentimentLabel + + +@dataclass(slots=True) +class SocialPost: + """Normalized social post structure.""" + + external_id: str + text: str + author: str + source: SocialSource + created_at: datetime + url: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + sentiment: Optional[SentimentResult] = None + + +@dataclass(slots=True) +class PipelineResult: + """Holds aggregate pipeline output.""" + + posts: List[SocialPost] + metrics: Dict[str, Any] diff --git a/sentiment/pipeline.py b/sentiment/pipeline.py new file mode 100644 index 0000000..1adcfc1 --- /dev/null +++ b/sentiment/pipeline.py @@ -0,0 +1,124 @@ +"""Sentiment pipeline orchestration.""" + +from __future__ import 
annotations + +import asyncio +import logging +from typing import Iterable, List, Sequence + +from .analyzer import SentimentAnalyzer +from .config import AppConfig +from .exceptions import DatabaseError, ValidationError +from .exporter import DataExporter +from .metrics import MetricsTracker +from .models import PipelineResult, SocialPost +from .repository import DataRepository +from .sanitizer import chunk_items, validate_post +from .fetchers.base import BaseFetcher + + +class SentimentPipeline: + """Coordinates fetching, analyzing, and persisting social posts.""" + + def __init__( + self, + *, + config: AppConfig, + repository: DataRepository, + analyzer: SentimentAnalyzer, + fetchers: Sequence[BaseFetcher], + exporter: DataExporter, + metrics: MetricsTracker, + logger: logging.Logger | None = None, + ) -> None: + self.config = config + self.repository = repository + self.analyzer = analyzer + self.fetchers = list(fetchers) + self.exporter = exporter + self.metrics = metrics + self.logger = logger or logging.getLogger(self.__class__.__name__) + self._semaphore = asyncio.Semaphore(config.concurrency) + + async def run( + self, + *, + query: str, + limit: int, + sentiment_filter: Sequence[str] | None = None, + sort_by: str = "created_at", + descending: bool = False, + export: bool = False, + ) -> PipelineResult: + """Execute the sentiment ingestion pipeline.""" + + fetched_posts = await self._gather_posts(query=query, limit=limit) + await self._process_posts(fetched_posts) + + posts = self.repository.fetch_posts( + sentiment_filter=sentiment_filter, + sort_by=sort_by, + descending=descending, + ) + if export: + filename = "sentiment_results" + self.exporter.export(posts, filename=filename, fmt=self.config.export_format) + + self.metrics.export(self.config.metrics_export_path) + return PipelineResult(posts=posts, metrics=self.metrics.as_dict()) + + async def _gather_posts(self, *, query: str, limit: int) -> List[SocialPost]: + tasks = [self._fetch_with_guard(fetcher, query, limit) for fetcher in self.fetchers] + results = await asyncio.gather(*tasks, return_exceptions=True) + posts: list[SocialPost] = [] + for result in results: + if isinstance(result, Exception): + self.metrics.record_failure() + self.logger.error("Fetcher failed", exc_info=result) + continue + posts.extend(result) + return posts + + async def _fetch_with_guard( + self, + fetcher: BaseFetcher, + query: str, + limit: int, + ) -> List[SocialPost]: + async with self._semaphore: + return await asyncio.wait_for( + fetcher.fetch(query, limit), + timeout=self.config.fetch_timeout, + ) + + async def _process_posts(self, posts: Sequence[SocialPost]) -> None: + total = len(posts) + if total == 0: + self.logger.info("No posts returned by fetchers") + return + + processed = 0 + for chunk in chunk_items(posts, self.config.batch_size): + await asyncio.gather(*(self._process_single(post) for post in chunk)) + processed += len(chunk) + self.logger.info("Processed %s/%s posts", processed, total) + + async def _process_single(self, post: SocialPost) -> None: + try: + validated = validate_post(post, self.config.timezone) + sentiment = self.analyzer.analyze(validated.text) + validated.sentiment = sentiment + inserted = await asyncio.to_thread(self.repository.insert_post, validated) + if inserted: + self.metrics.record_sentiment(sentiment.label.value) + else: + self.metrics.record_duplicate() + except ValidationError as exc: + self.metrics.record_failure() + self.logger.warning("Validation failed: %s", exc) + except DatabaseError as exc: + 
self.metrics.record_failure() + self.logger.error("Database error: %s", exc) + except Exception as exc: # pragma: no cover - safeguard + self.metrics.record_failure() + self.logger.exception("Unexpected processing error: %s", exc) diff --git a/sentiment/rate_limiter.py b/sentiment/rate_limiter.py new file mode 100644 index 0000000..04da9b6 --- /dev/null +++ b/sentiment/rate_limiter.py @@ -0,0 +1,37 @@ +"""Simple asynchronous rate limiter implementation.""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque + +from .exceptions import RateLimitError + + +class AsyncRateLimiter: + """Token bucket style async rate limiter.""" + + def __init__(self, max_calls: int, period: float) -> None: + self._max_calls = max_calls + self._period = period + self._calls = deque[float]() + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + """Acquire permission to perform an operation respecting the rate limit.""" + + while True: + async with self._lock: + now = time.monotonic() + while self._calls and now - self._calls[0] > self._period: + self._calls.popleft() + + if len(self._calls) < self._max_calls: + self._calls.append(now) + return + + sleep_for = self._period - (now - self._calls[0]) + if sleep_for > self._period * 2: + raise RateLimitError("Rate limiter queue overflowed") + await asyncio.sleep(max(0.0, sleep_for)) diff --git a/sentiment/repository.py b/sentiment/repository.py new file mode 100644 index 0000000..eb3bfce --- /dev/null +++ b/sentiment/repository.py @@ -0,0 +1,352 @@ +"""Database repository with connection pooling and migrations.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import threading +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from queue import Queue +from typing import Any, List, Sequence +from urllib.parse import urlparse + +try: # pragma: no cover - optional dependency at runtime + from psycopg2.pool import SimpleConnectionPool # type: ignore +except Exception: # pragma: no cover - handle environments without psycopg2 + SimpleConnectionPool = None # type: ignore + +from .config import AppConfig +from .exceptions import DatabaseError +from .models import SentimentLabel, SentimentResult, SocialPost, SocialSource + + +class SQLiteConnectionPool: + """Simple connection pool for sqlite3 connections.""" + + def __init__(self, database: str, max_size: int) -> None: + self._database = database + self._max_size = max_size + self._lock = threading.Lock() + self._created = 0 + self._queue: "Queue[sqlite3.Connection]" = Queue() + + def getconn(self) -> sqlite3.Connection: + with self._lock: + if not self._queue.empty(): + return self._queue.get() + if self._created < self._max_size: + conn = sqlite3.connect( + self._database, + detect_types=sqlite3.PARSE_DECLTYPES, + check_same_thread=False, + ) + conn.row_factory = sqlite3.Row + self._created += 1 + return conn + return self._queue.get() + + def putconn(self, connection: sqlite3.Connection) -> None: + self._queue.put(connection) + + def closeall(self) -> None: + while not self._queue.empty(): + conn = self._queue.get() + conn.close() + + +@dataclass +class SchemaManager: + """Apply schema migrations sequentially.""" + + repository: "DataRepository" + target_version: int + + def apply_migrations(self) -> None: + current_version = self._get_current_version() + while current_version < 
self.target_version: + next_version = current_version + 1 + migration = getattr(self, f"_migration_v{next_version}", None) + if migration is None: + raise DatabaseError(f"Missing migration for version {next_version}") + migration() + self._record_version(next_version) + current_version = next_version + + def _get_current_version(self) -> int: + self.repository._ensure_migrations_table() + rows = self.repository._execute("SELECT MAX(version) as version FROM schema_migrations", fetch=True) + if not rows or rows[0]["version"] is None: + return 0 + return int(rows[0]["version"]) + + def _record_version(self, version: int) -> None: + placeholder = self.repository.placeholder + query = f"INSERT INTO schema_migrations (version, applied_at) VALUES ({placeholder}, {placeholder})" + now = datetime.now(timezone.utc).isoformat() + self.repository._execute(query, (version, now)) + + def _migration_v1(self) -> None: + driver = self.repository.driver + timestamp_type = "TIMESTAMPTZ" if driver == "postgres" else "TEXT" + metadata_type = "JSONB" if driver == "postgres" else "TEXT" + id_column = ( + "SERIAL PRIMARY KEY" + if driver == "postgres" + else "INTEGER PRIMARY KEY AUTOINCREMENT" + ) + + schema_statements = [ + f""" + CREATE TABLE IF NOT EXISTS social_posts ( + id {id_column}, + external_id TEXT UNIQUE NOT NULL, + text TEXT NOT NULL, + author TEXT NOT NULL, + source TEXT NOT NULL, + created_at {timestamp_type} NOT NULL, + url TEXT, + sentiment_label TEXT, + sentiment_score REAL, + metadata {metadata_type}, + inserted_at {timestamp_type} DEFAULT CURRENT_TIMESTAMP + ); + """, + ] + + for statement in schema_statements: + self.repository._execute(statement) + + +class DataRepository: + """Provides database access with pooling and schema management.""" + + def __init__(self, config: AppConfig, logger: logging.Logger | None = None) -> None: + self.config = config + self.logger = logger or logging.getLogger(self.__class__.__name__) + self.driver: str + self.placeholder: str + self._pool = self._create_pool() + self.schema_manager = SchemaManager(self, config.schema_version) + self.schema_manager.apply_migrations() + + def _create_pool(self) -> Any: + parsed = urlparse(self.config.database_url) + scheme = (parsed.scheme or "sqlite").lower() + if scheme.startswith("postgres"): + if SimpleConnectionPool is None: + raise DatabaseError("psycopg2 is required for PostgreSQL connections") + self.driver = "postgres" + self.placeholder = "%s" + dbname = (parsed.path or "/")[1:] + pool = SimpleConnectionPool( + self.config.min_pool_size, + self.config.max_pool_size, + user=parsed.username, + password=parsed.password, + host=parsed.hostname, + port=parsed.port or 5432, + database=dbname, + connect_timeout=int(self.config.request_timeout), + ) + return pool + + if scheme.startswith("sqlite"): + self.driver = "sqlite" + self.placeholder = "?" 
+ path = parsed.path or parsed.netloc + if path.startswith("/"): + db_path = path + else: + db_path = f"./{path}" if path else "./sentiment.sqlite3" + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + return SQLiteConnectionPool(db_path, self.config.max_pool_size) + + raise DatabaseError(f"Unsupported database scheme: {scheme}") + + @contextmanager + def get_connection(self) -> Iterator[Any]: + connection = self._acquire_connection() + try: + yield connection + connection.commit() + except sqlite3.Error as exc: # pragma: no cover - exercised indirectly + connection.rollback() + self.logger.error("SQLite error: %s", exc) + raise DatabaseError(str(exc)) from exc + except Exception as exc: # pragma: no cover - exercised indirectly + connection.rollback() + self.logger.error("Database error: %s", exc) + raise + finally: + self._release_connection(connection) + + def _acquire_connection(self) -> Any: + if self.driver == "postgres": + return self._pool.getconn() + return self._pool.getconn() + + def _release_connection(self, connection: Any) -> None: + if self.driver == "postgres": + self._pool.putconn(connection) + else: + self._pool.putconn(connection) + + def _ensure_migrations_table(self) -> None: + timestamp_type = "TIMESTAMPTZ" if self.driver == "postgres" else "TEXT" + statement = f""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at {timestamp_type} NOT NULL + ); + """ + self._execute(statement) + + def insert_post(self, post: SocialPost) -> bool: + """Insert a post if it does not already exist.""" + + exists_query = ( + "SELECT 1 FROM social_posts WHERE external_id = " + f"{self.placeholder} LIMIT 1" + ) + params = (post.external_id,) + rows = self._execute(exists_query, params, fetch=True) + if rows: + return False + + insert_query = ( + "INSERT INTO social_posts (external_id, text, author, source, created_at, url, " + "sentiment_label, sentiment_score, metadata) " + f"VALUES ({', '.join([self.placeholder] * 9)})" + ) + try: + metadata_value = json.dumps(post.metadata) + except TypeError: + metadata_value = json.dumps({"raw": str(post.metadata)}) + sentiment_label = ( + post.sentiment.label.value if post.sentiment else None + ) + sentiment_score = ( + float(post.sentiment.score) if post.sentiment else None + ) + created_at_value = ( + post.created_at if self.driver == "postgres" else post.created_at.isoformat() + ) + insert_params = ( + post.external_id, + post.text, + post.author, + post.source.value, + created_at_value, + post.url, + sentiment_label, + sentiment_score, + metadata_value, + ) + self._execute(insert_query, insert_params) + return True + + def fetch_posts( + self, + *, + sentiment_filter: Sequence[str] | None = None, + limit: int | None = None, + sort_by: str = "created_at", + descending: bool = False, + ) -> List[SocialPost]: + """Fetch posts optionally filtered and sorted.""" + + allowed_sort = {"created_at", "author", "sentiment_score"} + if sort_by not in allowed_sort: + raise ValueError("Unsupported sort column") + + clauses: list[str] = [] + params: list[Any] = [] + if sentiment_filter: + placeholders = ", ".join([self.placeholder] * len(sentiment_filter)) + clauses.append(f"sentiment_label IN ({placeholders})") + params.extend(sentiment_filter) + + where_sql = f" WHERE {' AND '.join(clauses)}" if clauses else "" + order_sql = f" ORDER BY {sort_by} {'DESC' if descending else 'ASC'}" + limit_sql = f" LIMIT {limit}" if limit else "" + + query = ( + "SELECT external_id, text, author, source, created_at, url, 
sentiment_label, sentiment_score, metadata " + "FROM social_posts" + f"{where_sql}{order_sql}{limit_sql}" + ) + rows = self._execute(query, tuple(params), fetch=True) or [] + return [self._row_to_post(row) for row in rows] + + def export(self, exporter, *, sentiment_filter: Sequence[str] | None = None, filename: str = "posts") -> None: + posts = self.fetch_posts(sentiment_filter=sentiment_filter) + exporter.export(posts, filename=filename) + + def close(self) -> None: + """Close all pooled connections.""" + + if self.driver == "postgres": # pragma: no cover - requires postgres runtime + self._pool.closeall() + else: + self._pool.closeall() + + # Internal helpers ------------------------------------------------- + + def _execute( + self, + query: str, + params: Sequence[Any] | None = None, + *, + fetch: bool = False, + ) -> List[dict[str, Any]] | None: + params = params or [] + with self.get_connection() as connection: + cursor = connection.cursor() + try: + cursor.execute(query, params) + if fetch: + columns = [desc[0] for desc in cursor.description] # type: ignore[assignment] + rows = [dict(zip(columns, row)) for row in cursor.fetchall()] + return rows + finally: + cursor.close() + return None + + def _row_to_post(self, row: dict[str, Any]) -> SocialPost: + created_at_raw = row["created_at"] + if isinstance(created_at_raw, datetime): + created_at = created_at_raw + else: + created_at = datetime.fromisoformat(str(created_at_raw)) + + metadata_value = row.get("metadata") + if isinstance(metadata_value, str): + try: + metadata = json.loads(metadata_value) + except json.JSONDecodeError: + metadata = {"raw": metadata_value} + else: + metadata = metadata_value or {} + + sentiment = None + if row.get("sentiment_label"): + sentiment = SentimentResult( + score=float(row.get("sentiment_score") or 0.0), + label=SentimentLabel(row["sentiment_label"]), + ) + + return SocialPost( + external_id=row["external_id"], + text=row["text"], + author=row["author"], + source=SocialSource(row["source"]), + created_at=created_at, + url=row.get("url"), + metadata=metadata, + sentiment=sentiment, + ) diff --git a/sentiment/retry.py b/sentiment/retry.py new file mode 100644 index 0000000..4b65380 --- /dev/null +++ b/sentiment/retry.py @@ -0,0 +1,42 @@ +"""Async retry helpers.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Tuple, TypeVar + +T = TypeVar("T") + + +async def execute_with_retry( + operation: Callable[[], Awaitable[T]], + *, + max_attempts: int, + backoff_factor: float, + retry_exceptions: Tuple[type[BaseException], ...], + logger: logging.Logger, + base_delay: float = 1.0, +) -> T: + """Execute an async callable with exponential backoff.""" + + attempt = 0 + while True: + try: + return await operation() + except retry_exceptions as exc: # type: ignore[arg-type] + attempt += 1 + if attempt >= max_attempts: + logger.error("Operation failed after %s attempts", attempt, exc_info=exc) + raise + + delay = base_delay * (backoff_factor ** (attempt - 1)) + logger.warning( + "Operation failed with %s. 
Retrying in %.2fs (attempt %s/%s)", + exc, + delay, + attempt + 1, + max_attempts, + ) + await asyncio.sleep(delay) diff --git a/sentiment/sanitizer.py b/sentiment/sanitizer.py new file mode 100644 index 0000000..1fc8443 --- /dev/null +++ b/sentiment/sanitizer.py @@ -0,0 +1,59 @@ +"""Utilities for validating and sanitizing incoming data.""" + +from __future__ import annotations + +import re +from datetime import datetime +from typing import Iterable + +from zoneinfo import ZoneInfo + +from .exceptions import ValidationError +from .models import SocialPost + +_TEXT_FILTER = re.compile(r"[^\w\s.,!?/:;-]") +_WHITESPACE_RE = re.compile(r"\s+") + + +def sanitize_text(text: str) -> str: + """Normalize whitespace and strip unwanted characters.""" + + normalized = _WHITESPACE_RE.sub(" ", (text or "").strip()) + return _TEXT_FILTER.sub("", normalized) + + +def ensure_timezone(timestamp: datetime, timezone: str) -> datetime: + """Return the timestamp converted to the configured timezone.""" + + zone = ZoneInfo(timezone) + if timestamp.tzinfo is None: + return timestamp.replace(tzinfo=zone) + return timestamp.astimezone(zone) + + +def validate_post(post: SocialPost, timezone: str) -> SocialPost: + """Validate mandatory fields and normalize timestamps.""" + + if not post.external_id: + raise ValidationError("Post is missing an external identifier") + + sanitized_text = sanitize_text(post.text) + if not sanitized_text: + raise ValidationError("Post does not contain analyzable text") + + post.text = sanitized_text + post.created_at = ensure_timezone(post.created_at, timezone) + return post + + +def chunk_items(items: Iterable[SocialPost], chunk_size: int) -> Iterable[list[SocialPost]]: + """Yield posts in batches for easier processing.""" + + chunk: list[SocialPost] = [] + for item in items: + chunk.append(item) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if chunk: + yield chunk diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py new file mode 100644 index 0000000..e2db719 --- /dev/null +++ b/tests/test_analyzer.py @@ -0,0 +1,39 @@ +"""Tests for the sentiment analyzer service.""" + +from __future__ import annotations + +import unittest + +from sentiment.analyzer import SentimentAnalyzer +from sentiment.config import AppConfig +from sentiment.models import SentimentLabel + + +class SentimentAnalyzerTests(unittest.TestCase): + """Validate sentiment scoring.""" + + def setUp(self) -> None: + self.config = AppConfig( + database_url="sqlite:///tmp/test.db", + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + ) + self.analyzer = SentimentAnalyzer(self.config) + + def test_positive_sentiment(self) -> None: + result = self.analyzer.analyze("This is a great and awesome day") + self.assertEqual(result.label, SentimentLabel.POSITIVE) + + def test_negative_sentiment(self) -> None: + result = self.analyzer.analyze("This is a bad and terrible idea") + self.assertEqual(result.label, SentimentLabel.NEGATIVE) + + def test_neutral_sentiment(self) -> None: + result = self.analyzer.analyze("The sky is blue") + self.assertEqual(result.label, SentimentLabel.NEUTRAL) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..6532590 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +"""Tests for configuration management.""" + +from __future__ 
import annotations + +import os +import unittest +from unittest import mock + +from sentiment.config import AppConfig +from sentiment.exceptions import ConfigurationError + + +class AppConfigTests(unittest.TestCase): + """Validate AppConfig loading and validation.""" + + def test_from_env_missing_database_url(self) -> None: + with mock.patch.dict(os.environ, { + "TWITTER_BEARER_TOKEN": "token", + "REDDIT_CLIENT_ID": "client", + "REDDIT_CLIENT_SECRET": "secret", + }, clear=True): + with self.assertRaises(ConfigurationError): + AppConfig.from_env() + + def test_from_env_success(self) -> None: + with mock.patch.dict(os.environ, { + "DATABASE_URL": "sqlite:///tmp/test.db", + "TWITTER_BEARER_TOKEN": "token", + "REDDIT_CLIENT_ID": "client", + "REDDIT_CLIENT_SECRET": "secret", + "SENTIMENT_SOURCES": "twitter,reddit", + "REQUEST_TIMEOUT": "5", + "BATCH_SIZE": "10", + }, clear=True): + config = AppConfig.from_env() + self.assertEqual(config.database_url, "sqlite:///tmp/test.db") + self.assertEqual(config.batch_size, 10) + self.assertIn("twitter", config.enabled_sources) + self.assertIn("reddit", config.enabled_sources) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..1537638 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,89 @@ +"""Pipeline integration tests.""" + +from __future__ import annotations + +import tempfile +import unittest +from datetime import datetime, timezone +from pathlib import Path + +from sentiment.analyzer import SentimentAnalyzer +from sentiment.config import AppConfig +from sentiment.exporter import DataExporter +from sentiment.metrics import MetricsTracker +from sentiment.models import SocialPost, SocialSource +from sentiment.pipeline import SentimentPipeline +from sentiment.repository import DataRepository +from sentiment.fetchers.base import BaseFetcher + + +class DummyFetcher(BaseFetcher): + """Test double returning pre-defined posts.""" + + def __init__(self, config: AppConfig, posts: list[SocialPost]): + super().__init__(config) + self._posts = posts + + async def fetch(self, query: str, limit: int): + return self._posts[:limit] + + +class PipelineTests(unittest.IsolatedAsyncioTestCase): + """Validate the pipeline end-to-end using sqlite.""" + + async def asyncSetUp(self) -> None: + self.tmp_dir = tempfile.TemporaryDirectory() + self.export_dir = Path(self.tmp_dir.name) / "exports" + db_path = f"sqlite:///{self.tmp_dir.name}/pipeline.sqlite3" + self.config = AppConfig( + database_url=db_path, + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + batch_size=2, + ) + self.repository = DataRepository(self.config) + self.analyzer = SentimentAnalyzer(self.config) + self.exporter = DataExporter(self.export_dir) + self.metrics = MetricsTracker() + now = datetime.now(tz=timezone.utc) + posts = [ + SocialPost( + external_id="1", + text="I love async pipelines", + author="alice", + source=SocialSource.TWITTER, + created_at=now, + ), + SocialPost( + external_id="2", + text="I hate slow databases", + author="bob", + source=SocialSource.REDDIT, + created_at=now, + ), + ] + self.fetcher = DummyFetcher(self.config, posts) + self.pipeline = SentimentPipeline( + config=self.config, + repository=self.repository, + analyzer=self.analyzer, + fetchers=[self.fetcher], + exporter=self.exporter, + metrics=self.metrics, + ) + + async def asyncTearDown(self) -> None: + self.repository.close() + self.tmp_dir.cleanup() + 
+ async def test_pipeline_runs_and_persists_posts(self) -> None: + result = await self.pipeline.run(query="ai", limit=10, export=False) + self.assertGreaterEqual(result.metrics["posts_processed"], 1) + posts = self.repository.fetch_posts() + self.assertEqual(len(posts), 2) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_repository.py b/tests/test_repository.py new file mode 100644 index 0000000..3d94aba --- /dev/null +++ b/tests/test_repository.py @@ -0,0 +1,63 @@ +"""Tests for the data repository.""" + +from __future__ import annotations + +import tempfile +import unittest +from datetime import datetime, timezone + +from sentiment.config import AppConfig +from sentiment.models import SentimentLabel, SentimentResult, SocialPost, SocialSource +from sentiment.repository import DataRepository + + +class DataRepositoryTests(unittest.TestCase): + """Validate repository behaviour using sqlite.""" + + def setUp(self) -> None: + self._tmp_dir = tempfile.TemporaryDirectory() + db_path = f"sqlite:///{self._tmp_dir.name}/test.sqlite3" + self.config = AppConfig( + database_url=db_path, + twitter_bearer_token="token", + reddit_client_id="client", + reddit_client_secret="secret", + sources=[], + ) + self.repository = DataRepository(self.config) + + def tearDown(self) -> None: + self.repository.close() + self._tmp_dir.cleanup() + + def _build_post(self, external_id: str, sentiment_label: SentimentLabel) -> SocialPost: + post = SocialPost( + external_id=external_id, + text="I love testing" if sentiment_label == SentimentLabel.POSITIVE else "I hate bugs", + author="tester", + source=SocialSource.TWITTER, + created_at=datetime.now(tz=timezone.utc), + ) + post.sentiment = SentimentResult(score=0.9 if sentiment_label == SentimentLabel.POSITIVE else -0.9, label=sentiment_label) + return post + + def test_insert_and_fetch_posts(self) -> None: + post = self._build_post("abc", SentimentLabel.POSITIVE) + self.assertTrue(self.repository.insert_post(post)) + self.assertFalse(self.repository.insert_post(post)) # duplicate + posts = self.repository.fetch_posts() + self.assertEqual(len(posts), 1) + self.assertEqual(posts[0].external_id, "abc") + + def test_fetch_with_sentiment_filter(self) -> None: + positive = self._build_post("p", SentimentLabel.POSITIVE) + negative = self._build_post("n", SentimentLabel.NEGATIVE) + self.repository.insert_post(positive) + self.repository.insert_post(negative) + filtered = self.repository.fetch_posts(sentiment_filter=[SentimentLabel.POSITIVE.value]) + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0].external_id, "p") + + +if __name__ == "__main__": # pragma: no cover + unittest.main() diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py new file mode 100644 index 0000000..d53a7d0 --- /dev/null +++ b/tests/test_sanitizer.py @@ -0,0 +1,33 @@ +"""Tests for sanitizer utilities.""" + +from __future__ import annotations + +import unittest +from datetime import datetime + +from sentiment.models import SocialPost, SocialSource +from sentiment.sanitizer import sanitize_text, validate_post + + +class SanitizerTests(unittest.TestCase): + """Ensure sanitizer works as expected.""" + + def test_sanitize_text(self) -> None: + text = " Hello\tWorld!!@@## " + self.assertEqual(sanitize_text(text), "Hello World!!") + + def test_validate_post_normalizes_timezone(self) -> None: + post = SocialPost( + external_id="1", + text="Some text!!!", + author="alice", + source=SocialSource.TWITTER, + created_at=datetime(2024, 1, 1, 12, 0, 0), + 
) + validated = validate_post(post, "UTC") + self.assertIsNotNone(validated.created_at.tzinfo) + self.assertEqual(validated.text, "Some text!!!") + + +if __name__ == "__main__": # pragma: no cover + unittest.main()