From 60a070dbe78fbf662b1d2513336aa8324687452c Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Sat, 14 Feb 2026 14:53:34 -0800 Subject: [PATCH 1/3] feat: defer track import until after dedup for 88% fewer track rows Track tables (release_track, release_track_artist) are now imported after dedup instead of before, avoiding importing/deduplicating/indexing millions of track rows that would be discarded. Pre-computed track counts from CSV drive the dedup ranking instead of a live JOIN on release_track. New pipeline step order (8 steps, pipeline state v2): create_schema -> import_csv (base only) -> create_indexes (base only) -> dedup (base only) -> import_tracks -> create_track_indexes -> prune -> vacuum Key changes: - Split TABLES into BASE_TABLES + TRACK_TABLES with --base-only/--tracks-only flags - Pre-compute release_track_count table from CSV for dedup ranking - Filter track import to surviving release IDs after dedup - Split add_constraints_and_indexes() into base and track versions - Split create_indexes.sql; new create_track_indexes.sql for track indexes - Split trigram_indexes_exist() into base and track variants - Pipeline state v2 with v1 migration support - Dedup falls back to release_track if release_track_count doesn't exist --- lib/db_introspect.py | 40 ++++- lib/pipeline_state.py | 58 +++++++- schema/create_indexes.sql | 28 +--- schema/create_track_indexes.sql | 52 +++++++ scripts/dedup_releases.py | 145 +++++++++++++----- scripts/import_csv.py | 188 +++++++++++++++++++---- scripts/run_pipeline.py | 38 ++++- tests/e2e/test_pipeline.py | 4 +- tests/integration/test_db_introspect.py | 117 ++++++++++++++- tests/integration/test_dedup.py | 166 +++++++++++++++++---- tests/integration/test_import.py | 190 ++++++++++++++++++++++++ tests/integration/test_schema.py | 165 +++++++++++++++++--- tests/unit/test_import_csv.py | 71 +++++++++ tests/unit/test_pipeline_state.py | 136 ++++++++++++++++- 14 files changed, 1235 insertions(+), 163 deletions(-) create mode 100644 schema/create_track_indexes.sql diff --git a/lib/db_introspect.py b/lib/db_introspect.py index dc05b8f..0594fa3 100644 --- a/lib/db_introspect.py +++ b/lib/db_introspect.py @@ -53,8 +53,8 @@ def column_exists(db_url: str, table_name: str, column_name: str) -> bool: return result -def trigram_indexes_exist(db_url: str) -> bool: - """Return True if trigram GIN indexes exist on the expected tables.""" +def _get_trigram_indexes(db_url: str) -> set[str]: + """Return the set of trigram index names in the public schema.""" conn = psycopg.connect(db_url) with conn.cursor() as cur: cur.execute( @@ -63,15 +63,37 @@ def trigram_indexes_exist(db_url: str) -> bool: ) indexes = {row[0] for row in cur.fetchall()} conn.close() + return indexes + + +def base_trigram_indexes_exist(db_url: str) -> bool: + """Return True if base trigram GIN indexes exist (release, release_artist).""" + indexes = _get_trigram_indexes(db_url) expected = { - "idx_release_track_title_trgm", "idx_release_artist_name_trgm", - "idx_release_track_artist_name_trgm", "idx_release_title_trgm", } return expected.issubset(indexes) +def track_trigram_indexes_exist(db_url: str) -> bool: + """Return True if track trigram GIN indexes exist (release_track, release_track_artist).""" + indexes = _get_trigram_indexes(db_url) + expected = { + "idx_release_track_title_trgm", + "idx_release_track_artist_name_trgm", + } + return expected.issubset(indexes) + + +def trigram_indexes_exist(db_url: str) -> bool: + """Return True if all trigram GIN indexes exist (base + track). 
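As a quick orientation for the split introspection helpers above, a minimal usage sketch (assumes lib/ is importable and the default postgresql:///discogs URL; illustrative only, not lines added by this commit):

    from lib.db_introspect import infer_pipeline_state

    # Between dedup and the deferred track import, the inferred v2 state should
    # report dedup done but the two new track steps still pending.
    state = infer_pipeline_state("postgresql:///discogs")
    assert state.is_completed("dedup")
    assert not state.is_completed("import_tracks")
    assert not state.is_completed("create_track_indexes")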
+ + Backward-compatible convenience function. + """ + return base_trigram_indexes_exist(db_url) and track_trigram_indexes_exist(db_url) + + def infer_pipeline_state(db_url: str) -> PipelineState: """Infer pipeline state from database structure. @@ -91,7 +113,7 @@ def infer_pipeline_state(db_url: str) -> PipelineState: return state state.mark_completed("import_csv") - if not trigram_indexes_exist(db_url): + if not base_trigram_indexes_exist(db_url): return state state.mark_completed("create_indexes") @@ -99,5 +121,13 @@ def infer_pipeline_state(db_url: str) -> PipelineState: return state state.mark_completed("dedup") + if not table_has_rows(db_url, "release_track"): + return state + state.mark_completed("import_tracks") + + if not track_trigram_indexes_exist(db_url): + return state + state.mark_completed("create_track_indexes") + # prune and vacuum cannot be inferred from database state return state diff --git a/lib/pipeline_state.py b/lib/pipeline_state.py index 07b356f..60da0a6 100644 --- a/lib/pipeline_state.py +++ b/lib/pipeline_state.py @@ -9,9 +9,21 @@ import json from pathlib import Path -VERSION = 1 +VERSION = 2 -STEP_NAMES = ["create_schema", "import_csv", "create_indexes", "dedup", "prune", "vacuum"] +STEP_NAMES = [ + "create_schema", + "import_csv", + "create_indexes", + "dedup", + "import_tracks", + "create_track_indexes", + "prune", + "vacuum", +] + +# Mapping from v1 step names to v2 equivalents for migration +_V1_STEP_NAMES = ["create_schema", "import_csv", "create_indexes", "dedup", "prune", "vacuum"] class PipelineState: @@ -64,11 +76,51 @@ def save(self, path: Path) -> None: @classmethod def load(cls, path: Path) -> PipelineState: - """Load state from a JSON file.""" + """Load state from a JSON file. + + Supports v1 state files by migrating them to v2 format. + """ data = json.loads(path.read_text()) version = data.get("version") + + if version == 1: + return cls._migrate_v1(data) if version != VERSION: raise ValueError(f"Unsupported state file version {version} (expected {VERSION})") + state = cls(db_url=data["database_url"], csv_dir=data["csv_dir"]) state._steps = data["steps"] return state + + @classmethod + def _migrate_v1(cls, data: dict) -> PipelineState: + """Migrate a v1 state file to v2 format. + + V2 adds import_tracks and create_track_indexes between dedup and prune. 
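For concreteness, a sketch of how a v1 state file that finished through dedup migrates on load (values abridged; the csv_dir path and file name are placeholders):

    import json
    from pathlib import Path
    from lib.pipeline_state import PipelineState

    v1 = {
        "version": 1,
        "database_url": "postgresql:///discogs",
        "csv_dir": "/data/csv",              # placeholder
        "steps": {
            "create_schema": {"status": "completed"},
            "import_csv": {"status": "completed"},
            "create_indexes": {"status": "completed"},
            "dedup": {"status": "completed"},
        },
    }
    path = Path("pipeline_state.json")       # placeholder
    path.write_text(json.dumps(v1))

    state = PipelineState.load(path)          # v1 file -> migrated v2 state
    assert state.is_completed("import_tracks")          # inferred from import_csv
    assert state.is_completed("create_track_indexes")   # inferred from dedup
    assert not state.is_completed("prune")               # still pending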
+ + Migration rules: + - All v1 steps map directly to their v2 equivalents + - If import_csv was completed in v1, import_tracks is also completed + (v1 imported tracks as part of import_csv) + - If create_indexes or dedup was completed in v1, create_track_indexes + is also completed (v1 created track indexes during those steps) + """ + state = cls(db_url=data["database_url"], csv_dir=data["csv_dir"]) + v1_steps = data.get("steps", {}) + + # Copy v1 steps that exist in v2 + for step_name in _V1_STEP_NAMES: + if step_name in v1_steps: + state._steps[step_name] = v1_steps[step_name] + + # Infer import_tracks from import_csv + if v1_steps.get("import_csv", {}).get("status") == "completed": + state._steps["import_tracks"] = {"status": "completed"} + + # Infer create_track_indexes from dedup (v1 created all indexes in dedup) + if v1_steps.get("dedup", {}).get("status") == "completed": + state._steps["create_track_indexes"] = {"status": "completed"} + elif v1_steps.get("create_indexes", {}).get("status") == "completed": + state._steps["create_track_indexes"] = {"status": "completed"} + + return state diff --git a/schema/create_indexes.sql b/schema/create_indexes.sql index 59623d0..5dd357a 100644 --- a/schema/create_indexes.sql +++ b/schema/create_indexes.sql @@ -1,35 +1,23 @@ --- Create trigram indexes for fuzzy text search --- Run AFTER data import: psql -U postgres -d discogs -f 05-create-indexes.sql +-- Create base trigram indexes for fuzzy text search +-- Run AFTER base data import (release, release_artist). +-- Track-related indexes are in create_track_indexes.sql (run after track import). -- -- These indexes enable fast fuzzy matching using pg_trgm extension. --- Index creation on large tables can take 10-30 minutes. -- Ensure extension is loaded CREATE EXTENSION IF NOT EXISTS pg_trgm; -- ============================================ --- Trigram indexes for fuzzy text search +-- Base trigram indexes for fuzzy text search -- ============================================ --- 1. Track title search: "Find releases containing track 'Blue Monday'" --- Used by: search_releases_by_track() --- Query pattern: WHERE lower(f_unaccent(title)) % $1 OR lower(f_unaccent(title)) ILIKE ... -CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_track_title_trgm -ON release_track USING GIN (lower(f_unaccent(title)) gin_trgm_ops); - --- 2. Artist name search on releases: "Find releases by 'New Order'" +-- 1. Artist name search on releases: "Find releases by 'New Order'" -- Used by: search_releases_by_track() artist filter -- Query pattern: WHERE lower(f_unaccent(artist_name)) % $1 CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_artist_name_trgm ON release_artist USING GIN (lower(f_unaccent(artist_name)) gin_trgm_ops); --- 3. Track artist search: "Find compilation tracks by 'Joy Division'" --- Used by: validate_track_on_release() for compilations --- Query pattern: WHERE lower(f_unaccent(artist_name)) % $1 -CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_track_artist_name_trgm -ON release_track_artist USING GIN (lower(f_unaccent(artist_name)) gin_trgm_ops); - --- 4. Release title search: "Find releases named 'Power, Corruption & Lies'" +-- 2. 
Release title search: "Find releases named 'Power, Corruption & Lies'" -- Used by: get_release searches CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_title_trgm ON release USING GIN (lower(f_unaccent(title)) gin_trgm_ops); @@ -48,6 +36,6 @@ ON release USING GIN (lower(f_unaccent(title)) gin_trgm_ops); -- Test trigram search (should use index) -- EXPLAIN ANALYZE --- SELECT * FROM release_track --- WHERE lower(f_unaccent(title)) % 'blue monday' +-- SELECT * FROM release_artist +-- WHERE lower(f_unaccent(artist_name)) % 'new order' -- LIMIT 10; diff --git a/schema/create_track_indexes.sql b/schema/create_track_indexes.sql new file mode 100644 index 0000000..481a097 --- /dev/null +++ b/schema/create_track_indexes.sql @@ -0,0 +1,52 @@ +-- Create track-related FK constraints, FK indexes, and trigram indexes. +-- Run AFTER track import (release_track, release_track_artist). +-- +-- Base indexes are in create_indexes.sql (run after base import). +-- This file is idempotent: safe to run on resume. + +-- Ensure extension is loaded +CREATE EXTENSION IF NOT EXISTS pg_trgm; + +-- ============================================ +-- FK constraints (idempotent via DO blocks) +-- ============================================ + +DO $$ +BEGIN + ALTER TABLE release_track ADD CONSTRAINT fk_release_track_release + FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE; +EXCEPTION WHEN duplicate_object THEN NULL; +END $$; + +DO $$ +BEGIN + ALTER TABLE release_track_artist ADD CONSTRAINT fk_release_track_artist_release + FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE; +EXCEPTION WHEN duplicate_object THEN NULL; +END $$; + +-- ============================================ +-- FK indexes +-- ============================================ + +CREATE INDEX IF NOT EXISTS idx_release_track_release_id +ON release_track(release_id); + +CREATE INDEX IF NOT EXISTS idx_release_track_artist_release_id +ON release_track_artist(release_id); + +-- ============================================ +-- Trigram indexes for fuzzy text search +-- ============================================ + +-- Track title search: "Find releases containing track 'Blue Monday'" +-- Used by: search_releases_by_track() +-- Query pattern: WHERE lower(f_unaccent(title)) % $1 +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_track_title_trgm +ON release_track USING GIN (lower(f_unaccent(title)) gin_trgm_ops); + +-- Track artist search: "Find compilation tracks by 'Joy Division'" +-- Used by: validate_track_on_release() for compilations +-- Query pattern: WHERE lower(f_unaccent(artist_name)) % $1 +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_release_track_artist_name_trgm +ON release_track_artist USING GIN (lower(f_unaccent(artist_name)) gin_trgm_ops); diff --git a/scripts/dedup_releases.py b/scripts/dedup_releases.py index 39ef5f5..874f7cd 100644 --- a/scripts/dedup_releases.py +++ b/scripts/dedup_releases.py @@ -27,9 +27,24 @@ logger = logging.getLogger(__name__) +def _track_count_table_exists(conn) -> bool: + """Return True if the release_track_count pre-computed table exists.""" + with conn.cursor() as cur: + cur.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'release_track_count' + ) + """) + return cur.fetchone()[0] + + def ensure_dedup_ids(conn) -> int: """Ensure dedup_delete_ids table exists. Create if needed. + Uses release_track_count table for track counts if available (v2 pipeline), + falling back to counting from release_track directly (v1 / standalone usage). 
+ Returns number of IDs to delete. """ with conn.cursor() as cur: @@ -48,9 +63,28 @@ def ensure_dedup_ids(conn) -> int: logger.info(f"dedup_delete_ids already exists with {count:,} IDs") return count - logger.info("Creating dedup_delete_ids from ROW_NUMBER query...") + # Choose track count source: pre-computed table or live count from release_track + use_precomputed = _track_count_table_exists(conn) + + if use_precomputed: + logger.info( + "Creating dedup_delete_ids from ROW_NUMBER query (using pre-computed track counts)..." + ) + track_count_join = "JOIN release_track_count tc ON tc.release_id = r.id" + else: + logger.info( + "Creating dedup_delete_ids from ROW_NUMBER query (counting from release_track)..." + ) + track_count_join = ( + "JOIN (" + " SELECT release_id, COUNT(*) as track_count" + " FROM release_track" + " GROUP BY release_id" + ") tc ON tc.release_id = r.id" + ) + with conn.cursor() as cur: - cur.execute(""" + cur.execute(f""" CREATE UNLOGGED TABLE dedup_delete_ids AS SELECT id AS release_id FROM ( SELECT r.id, r.master_id, @@ -59,11 +93,7 @@ def ensure_dedup_ids(conn) -> int: ORDER BY tc.track_count DESC, r.id ASC ) as rn FROM release r - JOIN ( - SELECT release_id, COUNT(*) as track_count - FROM release_track - GROUP BY release_id - ) tc ON tc.release_id = r.id + {track_count_join} WHERE r.master_id IS NOT NULL ) ranked WHERE rn > 1 @@ -119,35 +149,29 @@ def swap_tables(conn, old_table: str, new_table: str) -> None: logger.info(f" Swapped {new_table} -> {old_table}") -def add_constraints_and_indexes(conn) -> None: - """Add PK, FK constraints and indexes to the new tables.""" - logger.info("Adding constraints and indexes...") +def add_base_constraints_and_indexes(conn) -> None: + """Add PK, FK constraints and indexes to base tables (no track tables). + + Called after dedup copy-swap. Track constraints are added separately + by create_track_indexes.sql after track import. 
+ """ + logger.info("Adding base constraints and indexes...") start = time.time() statements = [ # Primary key on release "ALTER TABLE release ADD PRIMARY KEY (id)", - # FK constraints with CASCADE + # FK constraints with CASCADE (base tables only) "ALTER TABLE release_artist ADD CONSTRAINT fk_release_artist_release " "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", - "ALTER TABLE release_track ADD CONSTRAINT fk_release_track_release " - "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", - "ALTER TABLE release_track_artist ADD CONSTRAINT fk_release_track_artist_release " - "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", "ALTER TABLE cache_metadata ADD CONSTRAINT fk_cache_metadata_release " "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", "ALTER TABLE cache_metadata ADD PRIMARY KEY (release_id)", - # FK indexes + # FK indexes (base tables only) "CREATE INDEX idx_release_artist_release_id ON release_artist(release_id)", - "CREATE INDEX idx_release_track_release_id ON release_track(release_id)", - "CREATE INDEX idx_release_track_artist_release_id ON release_track_artist(release_id)", - # Trigram indexes for fuzzy search (accent-insensitive via f_unaccent) - "CREATE INDEX idx_release_track_title_trgm ON release_track " - "USING gin (lower(f_unaccent(title)) gin_trgm_ops)", + # Base trigram indexes for fuzzy search (accent-insensitive via f_unaccent) "CREATE INDEX idx_release_artist_name_trgm ON release_artist " "USING gin (lower(f_unaccent(artist_name)) gin_trgm_ops)", - "CREATE INDEX idx_release_track_artist_name_trgm ON release_track_artist " - "USING gin (lower(f_unaccent(artist_name)) gin_trgm_ops)", "CREATE INDEX idx_release_title_trgm ON release " "USING gin (lower(f_unaccent(title)) gin_trgm_ops)", # Cache metadata indexes @@ -165,7 +189,55 @@ def add_constraints_and_indexes(conn) -> None: logger.info(f" done in {time.time() - stmt_start:.1f}s") elapsed = time.time() - start - logger.info(f"All constraints and indexes added in {elapsed:.1f}s") + logger.info(f"Base constraints and indexes added in {elapsed:.1f}s") + + +def add_track_constraints_and_indexes(conn) -> None: + """Add FK constraints and indexes to track tables. + + Called after track import (post-dedup). Equivalent to running + create_track_indexes.sql. 
+ """ + logger.info("Adding track constraints and indexes...") + start = time.time() + + statements = [ + # FK constraints with CASCADE + "ALTER TABLE release_track ADD CONSTRAINT fk_release_track_release " + "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", + "ALTER TABLE release_track_artist ADD CONSTRAINT fk_release_track_artist_release " + "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", + # FK indexes + "CREATE INDEX idx_release_track_release_id ON release_track(release_id)", + "CREATE INDEX idx_release_track_artist_release_id ON release_track_artist(release_id)", + # Track trigram indexes for fuzzy search + "CREATE INDEX idx_release_track_title_trgm ON release_track " + "USING gin (lower(f_unaccent(title)) gin_trgm_ops)", + "CREATE INDEX idx_release_track_artist_name_trgm ON release_track_artist " + "USING gin (lower(f_unaccent(artist_name)) gin_trgm_ops)", + ] + + with conn.cursor() as cur: + for i, stmt in enumerate(statements): + label = stmt.split("(")[0].strip() if "(" in stmt else stmt[:60] + logger.info(f" [{i + 1}/{len(statements)}] {label}...") + stmt_start = time.time() + cur.execute(stmt) + conn.commit() + logger.info(f" done in {time.time() - stmt_start:.1f}s") + + elapsed = time.time() - start + logger.info(f"Track constraints and indexes added in {elapsed:.1f}s") + + +def add_constraints_and_indexes(conn) -> None: + """Add PK, FK constraints and indexes to all tables. + + Convenience function that calls both base and track versions. + Used for backward compatibility (standalone dedup with all tables present). + """ + add_base_constraints_and_indexes(conn) + add_track_constraints_and_indexes(conn) def main(): @@ -178,27 +250,19 @@ def main(): delete_count = ensure_dedup_ids(conn) if delete_count == 0: logger.info("No duplicates found, nothing to do") + # Clean up release_track_count if it exists + with conn.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS release_track_count") conn.close() return total_start = time.time() # Step 2: Copy each table (keeping only non-duplicate rows) + # Only base tables + cache_metadata (tracks are imported after dedup) tables = [ ("release", "new_release", "id, title, release_year, artwork_url", "id"), ("release_artist", "new_release_artist", "release_id, artist_name, extra", "release_id"), - ( - "release_track", - "new_release_track", - "release_id, sequence, position, title, duration", - "release_id", - ), - ( - "release_track_artist", - "new_release_track_artist", - "release_id, track_sequence, artist_name", - "release_id", - ), ( "cache_metadata", "new_cache_metadata", @@ -210,13 +274,11 @@ def main(): for old, new, cols, id_col in tables: copy_table(conn, old, new, cols, id_col) - # Step 3: Drop old tables (order matters for FK constraints) + # Step 3: Drop old FK constraints before swap logger.info("Dropping FK constraints on old tables...") with conn.cursor() as cur: for stmt in [ "ALTER TABLE release_artist DROP CONSTRAINT IF EXISTS fk_release_artist_release", - "ALTER TABLE release_track DROP CONSTRAINT IF EXISTS fk_release_track_release", - "ALTER TABLE release_track_artist DROP CONSTRAINT IF EXISTS fk_release_track_artist_release", "ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS fk_cache_metadata_release", ]: cur.execute(stmt) @@ -226,13 +288,14 @@ def main(): for old, new, _, _ in tables: swap_tables(conn, old, new) - # Step 5: Add constraints and indexes - add_constraints_and_indexes(conn) + # Step 5: Add base constraints and indexes + add_base_constraints_and_indexes(conn) # 
Step 6: Cleanup logger.info("Cleaning up...") with conn.cursor() as cur: cur.execute("DROP TABLE IF EXISTS dedup_delete_ids") + cur.execute("DROP TABLE IF EXISTS release_track_count") # Step 7: Report with conn.cursor() as cur: diff --git a/scripts/import_csv.py b/scripts/import_csv.py index 64bfe83..66cfc77 100644 --- a/scripts/import_csv.py +++ b/scripts/import_csv.py @@ -43,6 +43,23 @@ def extract_year(released: str | None) -> str | None: return None +def count_tracks_from_csv(csv_path: Path) -> dict[int, int]: + """Count tracks per release_id from a release_track CSV file. + + Returns a dict mapping release_id -> track count. + """ + counts: dict[int, int] = {} + with open(csv_path, encoding="utf-8", errors="replace") as f: + reader = csv.DictReader(f) + for row in reader: + try: + release_id = int(row["release_id"]) + except (ValueError, KeyError): + continue + counts[release_id] = counts.get(release_id, 0) + 1 + return counts + + class TableConfig(TypedDict): csv_file: str table: str @@ -52,7 +69,7 @@ class TableConfig(TypedDict): transforms: dict[str, Callable[[str | None], str | None]] -TABLES: list[TableConfig] = [ +BASE_TABLES: list[TableConfig] = [ { "csv_file": "release.csv", "table": "release", @@ -70,6 +87,9 @@ class TableConfig(TypedDict): "transforms": {}, "unique_key": ["release_id", "artist_name"], }, +] + +TRACK_TABLES: list[TableConfig] = [ { "csv_file": "release_track.csv", "table": "release_track", @@ -89,6 +109,8 @@ class TableConfig(TypedDict): }, ] +TABLES: list[TableConfig] = BASE_TABLES + TRACK_TABLES + def import_csv( conn, @@ -99,6 +121,7 @@ def import_csv( required_columns: list[str], transforms: dict, unique_key: list[str] | None = None, + release_id_filter: set[int] | None = None, ) -> int: """Import a CSV file into a table, selecting only needed columns. @@ -108,6 +131,9 @@ def import_csv( If unique_key is provided, duplicate rows (by those CSV columns) are skipped, keeping the first occurrence. + + If release_id_filter is provided, only rows whose release_id is in the + set are imported. The CSV must have a 'release_id' or 'id' column. """ logger.info(f"Importing {csv_path.name} into {table}...") @@ -135,12 +161,32 @@ def import_csv( required_set = set(required_columns) seen: set[tuple[str | None, ...]] = set() + # Determine release_id column name for filtering + release_id_col: str | None = None + if release_id_filter is not None: + for col_name in ("release_id", "id"): + if col_name in csv_columns: + release_id_col = col_name + break + with conn.cursor() as cur: with cur.copy(f"COPY {table} ({db_col_list}) FROM STDIN") as copy: count = 0 skipped = 0 + filtered = 0 dupes = 0 for row in reader: + # Filter by release_id if specified + if release_id_filter is not None and release_id_col is not None: + try: + rid = int(row.get(release_id_col, "")) + except (ValueError, TypeError): + filtered += 1 + continue + if rid not in release_id_filter: + filtered += 1 + continue + # Extract only the columns we need values: list[str | None] = [] skip = False @@ -182,12 +228,46 @@ def import_csv( parts = [f"Imported {count:,} rows"] if skipped > 0: parts.append(f"skipped {skipped:,} with null required fields") + if filtered > 0: + parts.append(f"filtered {filtered:,} by release_id") if dupes > 0: parts.append(f"skipped {dupes:,} duplicates") logger.info(f" {', '.join(parts)}") return count +def create_track_count_table(conn, csv_dir: Path) -> int: + """Pre-compute track counts from CSV and store in release_track_count table. 
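A small sketch of where this pre-compute step sits in the --base-only phase (connection URL and CSV directory are placeholders; how the function is imported depends on your layout, e.g. the tests load scripts/import_csv.py via importlib):

    from pathlib import Path
    import psycopg

    conn = psycopg.connect("postgresql:///discogs")         # placeholder URL
    n = create_track_count_table(conn, Path("/data/csv"))   # placeholder CSV dir
    print(f"pre-computed track counts for {n:,} releases")
    # dedup_releases.py can now rank duplicates with
    #   JOIN release_track_count tc ON tc.release_id = r.id
    # even though release_track itself has not been imported yet.
    conn.close()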
+ + Creates an unlogged table with (release_id, track_count) that dedup uses + to rank releases by track count before tracks are imported. + + Returns the number of releases with track counts. + """ + csv_path = csv_dir / "release_track.csv" + if not csv_path.exists(): + logger.warning("release_track.csv not found, skipping track count table") + return 0 + + logger.info("Computing track counts from release_track.csv...") + counts = count_tracks_from_csv(csv_path) + + with conn.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS release_track_count") + cur.execute(""" + CREATE UNLOGGED TABLE release_track_count ( + release_id integer PRIMARY KEY, + track_count integer NOT NULL + ) + """) + with cur.copy("COPY release_track_count (release_id, track_count) FROM STDIN") as copy: + for release_id, track_count in counts.items(): + copy.write_row((release_id, track_count)) + conn.commit() + logger.info(f" Created release_track_count with {len(counts):,} rows") + return len(counts) + + def import_artwork(conn, csv_dir: Path) -> int: """Populate release.artwork_url from release_image.csv. @@ -261,23 +341,15 @@ def import_artwork(conn, csv_dir: Path) -> int: return len(artwork) -def main(): - if len(sys.argv) < 2: - print("Usage: import_csv.py [database_url]") - sys.exit(1) - - csv_dir = Path(sys.argv[1]) - db_url = sys.argv[2] if len(sys.argv) > 2 else "postgresql:///discogs" - - if not csv_dir.exists(): - logger.error(f"CSV directory not found: {csv_dir}") - sys.exit(1) - - logger.info(f"Connecting to {db_url}") - conn = psycopg.connect(db_url) - +def _import_tables( + conn, + csv_dir: Path, + table_list: list[TableConfig], + release_id_filter: set[int] | None = None, +) -> int: + """Import a list of table configs, returning total row count.""" total = 0 - for table_config in TABLES: + for table_config in table_list: csv_path = csv_dir / table_config["csv_file"] if not csv_path.exists(): logger.warning(f"Skipping {table_config['csv_file']} (not found)") @@ -292,22 +364,80 @@ def main(): table_config["required"], table_config["transforms"], unique_key=table_config.get("unique_key"), + release_id_filter=release_id_filter, ) total += count + return total - # Import artwork from release_image.csv - import_artwork(conn, csv_dir) - # Populate cache_metadata - logger.info("Populating cache_metadata...") - with conn.cursor() as cur: - cur.execute(""" - INSERT INTO cache_metadata (release_id, source) - SELECT id, 'bulk_import' - FROM release - ON CONFLICT (release_id) DO NOTHING - """) - conn.commit() +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Import Discogs CSV files into PostgreSQL") + parser.add_argument("csv_dir", type=Path, help="Directory containing CSV files") + parser.add_argument( + "db_url", + nargs="?", + default="postgresql:///discogs", + help="PostgreSQL connection URL", + ) + + mode = parser.add_mutually_exclusive_group() + mode.add_argument( + "--base-only", + action="store_true", + help="Import only base tables (release, release_artist) " + "plus artwork, cache_metadata, and track counts", + ) + mode.add_argument( + "--tracks-only", + action="store_true", + help="Import only track tables, filtered to surviving release IDs", + ) + + args = parser.parse_args() + csv_dir = args.csv_dir + db_url = args.db_url + + if not csv_dir.exists(): + logger.error(f"CSV directory not found: {csv_dir}") + sys.exit(1) + + logger.info(f"Connecting to {db_url}") + conn = psycopg.connect(db_url) + + if args.tracks_only: + # Query surviving release IDs from the database + 
with conn.cursor() as cur: + cur.execute("SELECT id FROM release") + release_ids = {row[0] for row in cur.fetchall()} + logger.info(f"Filtering tracks to {len(release_ids):,} surviving releases") + total = _import_tables(conn, csv_dir, TRACK_TABLES, release_id_filter=release_ids) + elif args.base_only: + total = _import_tables(conn, csv_dir, BASE_TABLES) + import_artwork(conn, csv_dir) + logger.info("Populating cache_metadata...") + with conn.cursor() as cur: + cur.execute(""" + INSERT INTO cache_metadata (release_id, source) + SELECT id, 'bulk_import' + FROM release + ON CONFLICT (release_id) DO NOTHING + """) + conn.commit() + create_track_count_table(conn, csv_dir) + else: + total = _import_tables(conn, csv_dir, TABLES) + import_artwork(conn, csv_dir) + logger.info("Populating cache_metadata...") + with conn.cursor() as cur: + cur.execute(""" + INSERT INTO cache_metadata (release_id, source) + SELECT id, 'bulk_import' + FROM release + ON CONFLICT (release_id) DO NOTHING + """) + conn.commit() logger.info(f"Total: {total:,} rows imported") conn.close() diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 2b8b21b..9181ec3 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -446,7 +446,8 @@ def _run_database_build( state: PipelineState | None = None, state_file: Path | None = None, ) -> None: - """Steps 4-9: database schema, import, indexes, dedup, prune/copy-to, vacuum. + """Steps 4-11: schema, base import, base indexes, dedup, track import, + track indexes, prune/copy-to, vacuum. When *state* is provided, completed steps are skipped and progress is saved to *state_file* after each step. @@ -472,19 +473,19 @@ def _save_state() -> None: state.mark_completed("create_schema") _save_state() - # Step 6: Import CSVs + # Step 6: Import base CSVs (release, release_artist, artwork, cache_metadata, track counts) if state and state.is_completed("import_csv"): logger.info("Skipping import_csv (already completed)") else: run_step( - "Import CSVs", - [python, str(SCRIPT_DIR / "import_csv.py"), str(csv_dir), db_url], + "Import base CSVs", + [python, str(SCRIPT_DIR / "import_csv.py"), "--base-only", str(csv_dir), db_url], ) if state: state.mark_completed("import_csv") _save_state() - # Step 7: Create trigram indexes (strip CONCURRENTLY for fresh DB) + # Step 7: Create base trigram indexes (strip CONCURRENTLY for fresh DB) if state and state.is_completed("create_indexes"): logger.info("Skipping create_indexes (already completed)") else: @@ -505,7 +506,28 @@ def _save_state() -> None: state.mark_completed("dedup") _save_state() - # Step 9: Prune or copy-to (optional) + # Step 9: Import tracks (filtered to surviving release IDs) + if state and state.is_completed("import_tracks"): + logger.info("Skipping import_tracks (already completed)") + else: + run_step( + "Import tracks", + [python, str(SCRIPT_DIR / "import_csv.py"), "--tracks-only", str(csv_dir), db_url], + ) + if state: + state.mark_completed("import_tracks") + _save_state() + + # Step 10: Create track indexes (FK constraints, FK indexes, trigram indexes) + if state and state.is_completed("create_track_indexes"): + logger.info("Skipping create_track_indexes (already completed)") + else: + run_sql_file(db_url, SCHEMA_DIR / "create_track_indexes.sql", strip_concurrently=True) + if state: + state.mark_completed("create_track_indexes") + _save_state() + + # Step 11: Prune or copy-to (optional) if state and state.is_completed("prune"): logger.info("Skipping prune/copy-to (already completed)") elif library_db and 
target_db_url: @@ -537,7 +559,7 @@ def _save_state() -> None: state.mark_completed("prune") _save_state() - # Step 10: Vacuum (on target DB if using copy-to, otherwise source) + # Step 12: Vacuum (on target DB if using copy-to, otherwise source) vacuum_db = target_db_url if target_db_url else db_url if state and state.is_completed("vacuum"): logger.info("Skipping vacuum (already completed)") @@ -547,7 +569,7 @@ def _save_state() -> None: state.mark_completed("vacuum") _save_state() - # Step 11: Report + # Step 13: Report report_sizes(vacuum_db) diff --git a/tests/e2e/test_pipeline.py b/tests/e2e/test_pipeline.py index cca3a24..f49defa 100644 --- a/tests/e2e/test_pipeline.py +++ b/tests/e2e/test_pipeline.py @@ -450,7 +450,7 @@ def test_all_steps_completed(self) -> None: def test_state_file_has_correct_metadata(self) -> None: """State file contains correct database URL and version.""" data = json.loads(self.__class__._state_file.read_text()) - assert data["version"] == 1 + assert data["version"] == 2 assert data["database_url"] == self.__class__._db_url @@ -521,6 +521,8 @@ def test_resume_skips_all_steps(self) -> None: assert "Skipping import_csv" in stderr assert "Skipping create_indexes" in stderr assert "Skipping dedup" in stderr + assert "Skipping import_tracks" in stderr + assert "Skipping create_track_indexes" in stderr assert "Skipping prune" in stderr assert "Skipping vacuum" in stderr diff --git a/tests/integration/test_db_introspect.py b/tests/integration/test_db_introspect.py index 9990c5c..b64fc0f 100644 --- a/tests/integration/test_db_introspect.py +++ b/tests/integration/test_db_introspect.py @@ -12,10 +12,12 @@ import pytest from lib.db_introspect import ( + base_trigram_indexes_exist, column_exists, infer_pipeline_state, table_exists, table_has_rows, + track_trigram_indexes_exist, trigram_indexes_exist, ) @@ -107,8 +109,56 @@ def test_false_after_column_dropped(self, db_url) -> None: assert not column_exists(db_url, "release", "master_id") +class TestBaseTrigramIndexesExist: + """base_trigram_indexes_exist() detects base GIN trigram indexes.""" + + def test_false_before_index_creation(self, db_url) -> None: + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + assert not base_trigram_indexes_exist(db_url) + + def test_true_after_base_index_creation(self, db_url) -> None: + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + conn.close() + assert base_trigram_indexes_exist(db_url) + + +class TestTrackTrigramIndexesExist: + """track_trigram_indexes_exist() detects track GIN trigram indexes.""" + + def test_false_before_index_creation(self, db_url) -> None: + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + assert not track_trigram_indexes_exist(db_url) + + def test_true_after_track_index_creation(self, db_url) -> None: + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + 
cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + conn.close() + assert track_trigram_indexes_exist(db_url) + + class TestTrigramIndexesExist: - """trigram_indexes_exist() detects GIN trigram indexes.""" + """trigram_indexes_exist() detects all GIN trigram indexes (backward compat).""" def test_false_before_index_creation(self, db_url) -> None: _clean_db(db_url) @@ -127,6 +177,9 @@ def test_true_after_index_creation(self, db_url) -> None: sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() sql = sql.replace(" CONCURRENTLY", "") cur.execute(sql) + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) conn.close() assert trigram_indexes_exist(db_url) @@ -137,7 +190,16 @@ class TestInferPipelineState: def test_empty_db(self, db_url) -> None: _clean_db(db_url) state = infer_pipeline_state(db_url) - for step in ["create_schema", "import_csv", "create_indexes", "dedup", "prune", "vacuum"]: + for step in [ + "create_schema", + "import_csv", + "create_indexes", + "dedup", + "import_tracks", + "create_track_indexes", + "prune", + "vacuum", + ]: assert not state.is_completed(step) def test_after_schema_creation(self, db_url) -> None: @@ -164,7 +226,7 @@ def test_after_import(self, db_url) -> None: assert state.is_completed("import_csv") assert not state.is_completed("create_indexes") - def test_after_indexes(self, db_url) -> None: + def test_after_base_indexes(self, db_url) -> None: _clean_db(db_url) conn = psycopg.connect(db_url, autocommit=True) with conn.cursor() as cur: @@ -198,6 +260,55 @@ def test_after_dedup(self, db_url) -> None: state = infer_pipeline_state(db_url) assert state.is_completed("dedup") + assert not state.is_completed("import_tracks") + assert not state.is_completed("create_track_indexes") + + def test_after_track_import(self, db_url) -> None: + """After track import, release_track has rows.""" + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test')") + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + cur.execute("ALTER TABLE release DROP COLUMN master_id") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'T1')" + ) + conn.close() + + state = infer_pipeline_state(db_url) + assert state.is_completed("dedup") + assert state.is_completed("import_tracks") + assert not state.is_completed("create_track_indexes") + + def test_after_track_indexes(self, db_url) -> None: + """After track indexes, all track trigram indexes exist.""" + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test')") + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + cur.execute("ALTER TABLE release DROP COLUMN master_id") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'T1')" + ) 
+ sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + conn.close() + + state = infer_pipeline_state(db_url) + assert state.is_completed("dedup") + assert state.is_completed("import_tracks") + assert state.is_completed("create_track_indexes") # prune and vacuum are never inferred assert not state.is_completed("prune") assert not state.is_completed("vacuum") diff --git a/tests/integration/test_dedup.py b/tests/integration/test_dedup.py index f8446a3..fb39268 100644 --- a/tests/integration/test_dedup.py +++ b/tests/integration/test_dedup.py @@ -27,10 +27,13 @@ import_csv_func = _ic.import_csv import_artwork = _ic.import_artwork -TABLES = _ic.TABLES +create_track_count_table = _ic.create_track_count_table +BASE_TABLES = _ic.BASE_TABLES +TRACK_TABLES = _ic.TRACK_TABLES ensure_dedup_ids = _dd.ensure_dedup_ids copy_table = _dd.copy_table swap_tables = _dd.swap_tables +add_base_constraints_and_indexes = _dd.add_base_constraints_and_indexes add_constraints_and_indexes = _dd.add_constraints_and_indexes pytestmark = pytest.mark.postgres @@ -51,13 +54,18 @@ def _drop_all_tables(conn) -> None: cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE") # Also drop dedup artifacts cur.execute("DROP TABLE IF EXISTS dedup_delete_ids CASCADE") + cur.execute("DROP TABLE IF EXISTS release_track_count CASCADE") for prefix in ("new_", ""): for table in ALL_TABLES: cur.execute(f"DROP TABLE IF EXISTS {prefix}{table}_old CASCADE") def _fresh_import(db_url: str) -> None: - """Drop everything, apply schema and functions, and import fixture CSVs.""" + """Drop everything, apply schema and functions, and import base fixture CSVs. + + Imports only BASE_TABLES (release, release_artist) plus artwork, cache_metadata, + and the release_track_count table. Track tables are NOT imported (deferred). 
+ """ conn = psycopg.connect(db_url, autocommit=True) _drop_all_tables(conn) with conn.cursor() as cur: @@ -66,7 +74,7 @@ def _fresh_import(db_url: str) -> None: conn.close() conn = psycopg.connect(db_url) - for table_config in TABLES: + for table_config in BASE_TABLES: csv_path = CSV_DIR / table_config["csv_file"] if csv_path.exists(): import_csv_func( @@ -86,14 +94,16 @@ def _fresh_import(db_url: str) -> None: ON CONFLICT (release_id) DO NOTHING """) conn.commit() + create_track_count_table(conn, CSV_DIR) conn.close() def _run_dedup(db_url: str) -> None: - """Run the full dedup pipeline against the database.""" + """Run the dedup pipeline (base tables only) against the database.""" conn = psycopg.connect(db_url, autocommit=True) delete_count = ensure_dedup_ids(conn) if delete_count > 0: + # Only base tables + cache_metadata (no track tables) tables = [ ("release", "new_release", "id, title, release_year, artwork_url", "id"), ( @@ -102,18 +112,6 @@ def _run_dedup(db_url: str) -> None: "release_id, artist_name, extra", "release_id", ), - ( - "release_track", - "new_release_track", - "release_id, sequence, position, title, duration", - "release_id", - ), - ( - "release_track_artist", - "new_release_track_artist", - "release_id, track_sequence, artist_name", - "release_id", - ), ( "cache_metadata", "new_cache_metadata", @@ -125,12 +123,45 @@ def _run_dedup(db_url: str) -> None: for old, new, cols, id_col in tables: copy_table(conn, old, new, cols, id_col) + # Drop FK constraints before swap + with conn.cursor() as cur: + for stmt in [ + "ALTER TABLE release_artist DROP CONSTRAINT IF EXISTS fk_release_artist_release", + "ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS fk_cache_metadata_release", + ]: + cur.execute(stmt) + for old, new, _, _ in tables: swap_tables(conn, old, new) - add_constraints_and_indexes(conn) + add_base_constraints_and_indexes(conn) with conn.cursor() as cur: cur.execute("DROP TABLE IF EXISTS dedup_delete_ids") + cur.execute("DROP TABLE IF EXISTS release_track_count") + conn.close() + + +def _import_tracks_after_dedup(db_url: str) -> None: + """Import tracks filtered to surviving release IDs after dedup.""" + conn = psycopg.connect(db_url) + with conn.cursor() as cur: + cur.execute("SELECT id FROM release") + release_ids = {row[0] for row in cur.fetchall()} + + for table_config in TRACK_TABLES: + csv_path = CSV_DIR / table_config["csv_file"] + if csv_path.exists(): + import_csv_func( + conn, + csv_path, + table_config["table"], + table_config["csv_columns"], + table_config["db_columns"], + table_config["required"], + table_config["transforms"], + unique_key=table_config.get("unique_key"), + release_id_filter=release_ids, + ) conn.close() @@ -139,10 +170,11 @@ class TestDedup: @pytest.fixture(autouse=True, scope="class") def _set_up_and_dedup(self, db_url): - """Import fixtures and run dedup (once per class).""" + """Import base fixtures, run dedup, then import tracks.""" self.__class__._db_url = db_url _fresh_import(db_url) _run_dedup(db_url) + _import_tracks_after_dedup(db_url) @pytest.fixture(autouse=True) def _store_url(self): @@ -188,7 +220,7 @@ def test_null_master_id_release_untouched(self) -> None: assert count == 1 def test_child_table_rows_cleaned(self) -> None: - """Deduped releases have their child table rows removed.""" + """Deduped releases have their child table rows removed (not imported).""" conn = self._connect() with conn.cursor() as cur: cur.execute("SELECT count(*) FROM release_artist WHERE release_id = 1001") @@ -200,7 +232,7 @@ def 
test_child_table_rows_cleaned(self) -> None: assert track_count == 0 def test_kept_release_tracks_preserved(self) -> None: - """The kept release still has its tracks.""" + """The kept release still has its tracks (imported after dedup).""" conn = self._connect() with conn.cursor() as cur: cur.execute("SELECT count(*) FROM release_track WHERE release_id = 1002") @@ -232,8 +264,8 @@ def test_primary_key_recreated(self) -> None: conn.close() assert result is not None - def test_fk_constraints_recreated(self) -> None: - """FK constraints on child tables exist after dedup.""" + def test_base_fk_constraints_recreated(self) -> None: + """FK constraints on base child tables exist after dedup.""" conn = self._connect() with conn.cursor() as cur: cur.execute(""" @@ -243,9 +275,20 @@ def test_fk_constraints_recreated(self) -> None: """) fk_tables = {row[0] for row in cur.fetchall()} conn.close() - expected = {"release_artist", "release_track", "release_track_artist", "cache_metadata"} + expected = {"release_artist", "cache_metadata"} assert expected.issubset(fk_tables) + def test_track_tables_empty_before_track_import(self) -> None: + """Track tables are empty after base-only dedup (verified via count of + deduped releases -- 1001, 1003, 2001 should have no tracks).""" + conn = self._connect() + with conn.cursor() as cur: + # Deduped release 1001 should have no tracks (it was deleted) + cur.execute("SELECT count(*) FROM release_track WHERE release_id = 1001") + count = cur.fetchone()[0] + conn.close() + assert count == 0 + def test_total_release_count_after_dedup(self) -> None: """Total releases: 15 imported - 3 duplicates = 12.""" conn = self._connect() @@ -256,6 +299,20 @@ def test_total_release_count_after_dedup(self) -> None: # 15 imported (7001 skipped), 1001+1003 removed (master 500), 2001 removed (master 600) assert count == 12 + def test_release_track_count_dropped(self) -> None: + """release_track_count table is cleaned up after dedup.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (" + " SELECT 1 FROM information_schema.tables" + " WHERE table_name = 'release_track_count'" + ")" + ) + exists = cur.fetchone()[0] + conn.close() + assert not exists + class TestDedupNoop: """Verify dedup is a no-op when there are no duplicates.""" @@ -269,12 +326,15 @@ def _set_up(self, db_url): cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) cur.execute("INSERT INTO release (id, title, master_id) VALUES (1, 'A', 100)") cur.execute("INSERT INTO release (id, title, master_id) VALUES (2, 'B', 200)") - cur.execute( - "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'Track A')" - ) - cur.execute( - "INSERT INTO release_track (release_id, sequence, title) VALUES (2, 1, 'Track B')" - ) + # Use release_track_count instead of release_track for ranking + cur.execute(""" + CREATE UNLOGGED TABLE release_track_count ( + release_id integer PRIMARY KEY, + track_count integer NOT NULL + ) + """) + cur.execute("INSERT INTO release_track_count (release_id, track_count) VALUES (1, 3)") + cur.execute("INSERT INTO release_track_count (release_id, track_count) VALUES (2, 5)") conn.close() @pytest.fixture(autouse=True) @@ -290,3 +350,51 @@ def test_no_duplicates_found(self) -> None: cur.execute("DROP TABLE IF EXISTS dedup_delete_ids") conn.close() assert count == 0 + + +class TestDedupFallback: + """Verify dedup falls back to release_track when release_track_count doesn't exist.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, 
db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + # Two releases with same master_id + cur.execute("INSERT INTO release (id, title, master_id) VALUES (1, 'A', 100)") + cur.execute("INSERT INTO release (id, title, master_id) VALUES (2, 'B', 100)") + # Release 2 has more tracks + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'T1')" + ) + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (2, 1, 'T1')" + ) + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (2, 2, 'T2')" + ) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_fallback_to_release_track(self) -> None: + """Without release_track_count, ensure_dedup_ids uses release_track.""" + conn = psycopg.connect(self.db_url, autocommit=True) + count = ensure_dedup_ids(conn) + conn.close() + # Release 1 should be marked for deletion (fewer tracks) + assert count == 1 + + def test_correct_release_deleted(self) -> None: + """Release 1 (1 track) is deleted, release 2 (2 tracks) is kept.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute("SELECT release_id FROM dedup_delete_ids") + ids = [row[0] for row in cur.fetchall()] + cur.execute("DROP TABLE IF EXISTS dedup_delete_ids") + conn.close() + assert ids == [1] diff --git a/tests/integration/test_import.py b/tests/integration/test_import.py index 2f40cd7..fabdc53 100644 --- a/tests/integration/test_import.py +++ b/tests/integration/test_import.py @@ -21,7 +21,10 @@ import_csv_func = _ic.import_csv import_artwork = _ic.import_artwork +create_track_count_table = _ic.create_track_count_table TABLES = _ic.TABLES +BASE_TABLES = _ic.BASE_TABLES +TRACK_TABLES = _ic.TRACK_TABLES pytestmark = pytest.mark.postgres @@ -201,3 +204,190 @@ def test_cache_metadata_source(self) -> None: sources = {row[0] for row in cur.fetchall()} conn.close() assert sources == {"bulk_import"} + + +ALL_TABLES = ( + "cache_metadata", + "release_track_artist", + "release_track", + "release_artist", + "release", +) + + +def _clean_db(db_url: str) -> None: + """Drop all pipeline tables and artifacts.""" + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + for table in ALL_TABLES: + cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE") + cur.execute("DROP TABLE IF EXISTS release_track_count CASCADE") + conn.close() + + +class TestTrackCountTable: + """Verify create_track_count_table() creates the right data.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + """Apply schema, import base tables, then create track count table.""" + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + conn = psycopg.connect(db_url) + for table_config in BASE_TABLES: + csv_path = CSV_DIR / table_config["csv_file"] + if csv_path.exists(): + import_csv_func( + conn, + csv_path, + table_config["table"], + table_config["csv_columns"], + table_config["db_columns"], + table_config["required"], + table_config["transforms"], + ) + create_track_count_table(conn, CSV_DIR) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + 
self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_table_exists(self) -> None: + conn = self._connect() + with conn.cursor() as cur: + cur.execute( + "SELECT EXISTS (" + " SELECT 1 FROM information_schema.tables" + " WHERE table_name = 'release_track_count'" + ")" + ) + exists = cur.fetchone()[0] + conn.close() + assert exists + + def test_row_count(self) -> None: + """One row per release_id that has tracks in the CSV.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_track_count") + count = cur.fetchone()[0] + conn.close() + # 15 distinct release_ids in release_track.csv + assert count == 15 + + def test_correct_counts(self) -> None: + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT track_count FROM release_track_count WHERE release_id = 1002") + count = cur.fetchone()[0] + conn.close() + assert count == 5 + + def test_track_tables_empty(self) -> None: + """Base-only import should not populate track tables.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_track") + count = cur.fetchone()[0] + conn.close() + assert count == 0 + + +class TestFilteredTrackImport: + """Import tracks filtered to a subset of release IDs.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + """Import base tables, then import tracks filtered to a subset.""" + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + conn = psycopg.connect(db_url) + # Import base tables + for table_config in BASE_TABLES: + csv_path = CSV_DIR / table_config["csv_file"] + if csv_path.exists(): + import_csv_func( + conn, + csv_path, + table_config["table"], + table_config["csv_columns"], + table_config["db_columns"], + table_config["required"], + table_config["transforms"], + ) + + # Import tracks filtered to only a subset of releases + filter_ids = {1002, 3001, 4001} + for table_config in TRACK_TABLES: + csv_path = CSV_DIR / table_config["csv_file"] + if csv_path.exists(): + import_csv_func( + conn, + csv_path, + table_config["table"], + table_config["csv_columns"], + table_config["db_columns"], + table_config["required"], + table_config["transforms"], + unique_key=table_config.get("unique_key"), + release_id_filter=filter_ids, + ) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_only_filtered_tracks_imported(self) -> None: + """Only tracks for the filtered release IDs should be present.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT DISTINCT release_id FROM release_track ORDER BY release_id") + ids = [row[0] for row in cur.fetchall()] + conn.close() + assert ids == [1002, 3001, 4001] + + def test_excluded_release_has_no_tracks(self) -> None: + """Releases not in the filter set should have no tracks.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_track WHERE release_id = 1001") + count = cur.fetchone()[0] + conn.close() + assert count == 0 + + def test_included_release_has_correct_track_count(self) -> None: + """Release 1002 should have all 5 tracks.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT 
count(*) FROM release_track WHERE release_id = 1002") + count = cur.fetchone()[0] + conn.close() + assert count == 5 + + def test_total_track_count(self) -> None: + """Total tracks should be the sum for the filtered releases.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_track") + count = cur.fetchone()[0] + conn.close() + # 1002: 5, 3001: 2, 4001: 2 = 9 + assert count == 9 diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 18b0eb6..b874512 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -159,8 +159,8 @@ def test_schema_is_idempotent(self) -> None: conn.close() -class TestCreateIndexes: - """Verify create_indexes.sql can be applied after data import.""" +class TestCreateBaseIndexes: + """Verify create_indexes.sql creates base trigram indexes.""" @pytest.fixture(autouse=True) def _apply_schema_and_data(self, db_url): @@ -170,7 +170,6 @@ def _apply_schema_and_data(self, db_url): with conn.cursor() as cur: cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) - # Insert minimal data so indexes have something to work with cur.execute( "INSERT INTO release (id, title) VALUES (1, 'Test Album') ON CONFLICT DO NOTHING" ) @@ -179,6 +178,73 @@ def _apply_schema_and_data(self, db_url): "SELECT 1, 'Test Artist', 0 WHERE NOT EXISTS " "(SELECT 1 FROM release_artist WHERE release_id = 1)" ) + conn.close() + + def test_base_indexes_execute_without_error(self) -> None: + """Base trigram indexes can be created after data is loaded.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + conn.close() + + def test_base_trigram_indexes_exist(self) -> None: + """Base trigram indexes (release, release_artist) are created.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE schemaname = 'public' + AND indexname LIKE '%trgm%' + """) + indexes = {row[0] for row in cur.fetchall()} + conn.close() + expected = { + "idx_release_artist_name_trgm", + "idx_release_title_trgm", + } + assert expected.issubset(indexes) + + def test_base_trigram_indexes_use_unaccent(self) -> None: + """Base trigram indexes use f_unaccent() for accent-insensitive matching.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + + cur.execute(""" + SELECT indexname, indexdef FROM pg_indexes + WHERE schemaname = 'public' + AND indexname LIKE '%trgm%' + """) + rows = cur.fetchall() + conn.close() + for indexname, indexdef in rows: + assert "f_unaccent" in indexdef, ( + f"Index {indexname} should use f_unaccent(): {indexdef}" + ) + + +class TestCreateTrackIndexes: + """Verify create_track_indexes.sql creates track-related indexes and constraints.""" + + @pytest.fixture(autouse=True) + def _apply_schema_and_data(self, db_url): + """Set up schema, functions, and insert minimal sample data for track indexes.""" + self.db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + 
cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + cur.execute( + "INSERT INTO release (id, title) VALUES (1, 'Test Album') ON CONFLICT DO NOTHING" + ) cur.execute( "INSERT INTO release_track (release_id, sequence, position, title) " "SELECT 1, 1, 'A1', 'Test Track' WHERE NOT EXISTS " @@ -191,21 +257,20 @@ def _apply_schema_and_data(self, db_url): ) conn.close() - def test_indexes_execute_without_error(self) -> None: - """Trigram indexes can be created after data is loaded.""" + def test_track_indexes_execute_without_error(self) -> None: + """Track indexes can be created after track data is loaded.""" conn = psycopg.connect(self.db_url, autocommit=True) with conn.cursor() as cur: - # Strip CONCURRENTLY since we're in a test with autocommit - sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() sql = sql.replace(" CONCURRENTLY", "") cur.execute(sql) conn.close() - def test_trigram_indexes_exist(self) -> None: - """All four trigram indexes are created.""" + def test_track_trigram_indexes_exist(self) -> None: + """Track trigram indexes (release_track, release_track_artist) are created.""" conn = psycopg.connect(self.db_url, autocommit=True) with conn.cursor() as cur: - sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() sql = sql.replace(" CONCURRENTLY", "") cur.execute(sql) @@ -218,28 +283,86 @@ def test_trigram_indexes_exist(self) -> None: conn.close() expected = { "idx_release_track_title_trgm", - "idx_release_artist_name_trgm", "idx_release_track_artist_name_trgm", - "idx_release_title_trgm", } assert expected.issubset(indexes) - def test_trigram_indexes_use_unaccent(self) -> None: - """All four trigram indexes use f_unaccent() for accent-insensitive matching.""" + def test_track_fk_constraints_created(self) -> None: + """FK constraints on track tables are created.""" conn = psycopg.connect(self.db_url, autocommit=True) with conn.cursor() as cur: - sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() sql = sql.replace(" CONCURRENTLY", "") cur.execute(sql) cur.execute(""" - SELECT indexname, indexdef FROM pg_indexes + SELECT tc.constraint_name + FROM information_schema.table_constraints tc + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name IN ('release_track', 'release_track_artist') + """) + constraints = {row[0] for row in cur.fetchall()} + conn.close() + expected = {"fk_release_track_release", "fk_release_track_artist_release"} + assert expected.issubset(constraints) + + def test_track_fk_indexes_created(self) -> None: + """FK indexes on track tables are created.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + + cur.execute(""" + SELECT indexname FROM pg_indexes WHERE schemaname = 'public' - AND indexname LIKE '%trgm%' + AND indexname LIKE 'idx_release_track%' """) - rows = cur.fetchall() + indexes = {row[0] for row in cur.fetchall()} conn.close() - for indexname, indexdef in rows: - assert "f_unaccent" in indexdef, ( - f"Index {indexname} should use f_unaccent(): {indexdef}" + expected = { + "idx_release_track_release_id", + "idx_release_track_artist_release_id", + } + assert 
expected.issubset(indexes) + + def test_track_indexes_idempotent(self) -> None: + """Running create_track_indexes.sql twice doesn't error.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + cur.execute(sql) + conn.close() + + def test_all_trigram_indexes_after_both_sql_files(self) -> None: + """All four trigram indexes exist after running both SQL files.""" + conn = psycopg.connect(self.db_url, autocommit=True) + with conn.cursor() as cur: + # Need release_artist data for base indexes + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "SELECT 1, 'Test Artist', 0 WHERE NOT EXISTS " + "(SELECT 1 FROM release_artist WHERE release_id = 1)" ) + for sql_file in ("create_indexes.sql", "create_track_indexes.sql"): + sql = SCHEMA_DIR.joinpath(sql_file).read_text() + sql = sql.replace(" CONCURRENTLY", "") + cur.execute(sql) + + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE schemaname = 'public' + AND indexname LIKE '%trgm%' + """) + indexes = {row[0] for row in cur.fetchall()} + conn.close() + expected = { + "idx_release_artist_name_trgm", + "idx_release_title_trgm", + "idx_release_track_title_trgm", + "idx_release_track_artist_name_trgm", + } + assert expected.issubset(indexes) diff --git a/tests/unit/test_import_csv.py b/tests/unit/test_import_csv.py index f55afa1..1701191 100644 --- a/tests/unit/test_import_csv.py +++ b/tests/unit/test_import_csv.py @@ -15,9 +15,15 @@ _spec.loader.exec_module(_ic) extract_year = _ic.extract_year +count_tracks_from_csv = _ic.count_tracks_from_csv TABLES = _ic.TABLES +BASE_TABLES = _ic.BASE_TABLES +TRACK_TABLES = _ic.TRACK_TABLES TableConfig = _ic.TableConfig +FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" +CSV_DIR = FIXTURES_DIR / "csv" + # --------------------------------------------------------------------------- # extract_year @@ -171,3 +177,68 @@ def test_release_artist_csv_has_expected_columns(self) -> None: assert col in headers, ( f"Expected column {col!r} not in release_artist.csv headers: {headers}" ) + + +# --------------------------------------------------------------------------- +# count_tracks_from_csv +# --------------------------------------------------------------------------- + + +class TestCountTracksFromCsv: + """Count tracks per release_id from a release_track CSV file.""" + + def test_counts_tracks_per_release(self) -> None: + """Returns a dict mapping release_id -> track count.""" + csv_path = CSV_DIR / "release_track.csv" + counts = count_tracks_from_csv(csv_path) + # Release 1002 has 5 tracks in the fixture data + assert counts[1002] == 5 + + def test_all_releases_counted(self) -> None: + """Every release_id in the CSV has an entry.""" + csv_path = CSV_DIR / "release_track.csv" + counts = count_tracks_from_csv(csv_path) + assert len(counts) == 15 + + def test_returns_empty_for_nonexistent_file(self, tmp_path) -> None: + """Returns empty dict when file is empty (only header).""" + csv_path = tmp_path / "empty.csv" + csv_path.write_text("release_id,sequence,position,title,duration\n") + counts = count_tracks_from_csv(csv_path) + assert counts == {} + + def test_skips_invalid_release_ids(self, tmp_path) -> None: + """Rows with non-integer release_id are skipped.""" + csv_path = tmp_path / "bad.csv" + csv_path.write_text( + "release_id,sequence,position,title,duration\n" + "abc,1,A1,Track 1,3:00\n" + "1,1,A1,Track 1,3:00\n" + ) 
+ counts = count_tracks_from_csv(csv_path) + assert counts == {1: 1} + + +# --------------------------------------------------------------------------- +# BASE_TABLES / TRACK_TABLES split +# --------------------------------------------------------------------------- + + +class TestTableSplit: + """TABLES is the union of BASE_TABLES and TRACK_TABLES.""" + + def test_tables_is_union(self) -> None: + assert TABLES == BASE_TABLES + TRACK_TABLES + + def test_base_tables_are_release_and_release_artist(self) -> None: + names = [t["table"] for t in BASE_TABLES] + assert names == ["release", "release_artist"] + + def test_track_tables_are_release_track_and_release_track_artist(self) -> None: + names = [t["table"] for t in TRACK_TABLES] + assert names == ["release_track", "release_track_artist"] + + def test_no_overlap(self) -> None: + base_names = {t["table"] for t in BASE_TABLES} + track_names = {t["table"] for t in TRACK_TABLES} + assert base_names.isdisjoint(track_names) diff --git a/tests/unit/test_pipeline_state.py b/tests/unit/test_pipeline_state.py index 42e6999..9e3541d 100644 --- a/tests/unit/test_pipeline_state.py +++ b/tests/unit/test_pipeline_state.py @@ -6,9 +6,9 @@ import pytest -from lib.pipeline_state import PipelineState +from lib.pipeline_state import STEP_NAMES, PipelineState -STEPS = ["create_schema", "import_csv", "create_indexes", "dedup", "prune", "vacuum"] +STEPS = STEP_NAMES class TestFreshState: @@ -22,6 +22,23 @@ def test_no_steps_completed(self) -> None: for step in STEPS: assert not state.is_completed(step) + def test_step_count(self) -> None: + """V2 pipeline has 8 steps.""" + assert len(STEPS) == 8 + + def test_step_order(self) -> None: + """Steps are in correct execution order.""" + assert STEPS == [ + "create_schema", + "import_csv", + "create_indexes", + "dedup", + "import_tracks", + "create_track_indexes", + "prune", + "vacuum", + ] + class TestMarkCompleted: """mark_completed() / is_completed() round-trip.""" @@ -37,6 +54,13 @@ def test_other_steps_remain_pending(self) -> None: assert not state.is_completed("create_schema") assert not state.is_completed("dedup") + def test_new_steps_can_be_marked(self) -> None: + state = PipelineState(db_url="postgresql://localhost/test", csv_dir="/tmp/csv") + state.mark_completed("import_tracks") + assert state.is_completed("import_tracks") + state.mark_completed("create_track_indexes") + assert state.is_completed("create_track_indexes") + class TestMarkFailed: """mark_failed() records error message.""" @@ -59,7 +83,7 @@ def test_save_creates_valid_json(self, tmp_path) -> None: state.save(state_file) data = json.loads(state_file.read_text()) - assert data["version"] == 1 + assert data["version"] == 2 assert data["database_url"] == "postgresql://localhost/test" assert data["csv_dir"] == "/tmp/csv" assert data["steps"]["create_schema"]["status"] == "completed" @@ -92,6 +116,112 @@ def test_save_is_atomic(self, tmp_path) -> None: tmp_files = list(tmp_path.glob("*.tmp")) assert tmp_files == [] + def test_v2_round_trip(self, tmp_path) -> None: + """Save and load a v2 state with all steps.""" + state_file = tmp_path / "state.json" + state = PipelineState(db_url="postgresql://localhost/test", csv_dir="/tmp/csv") + for step in STEPS: + state.mark_completed(step) + state.save(state_file) + + loaded = PipelineState.load(state_file) + for step in STEPS: + assert loaded.is_completed(step), f"Step {step} should be completed" + + def test_v2_has_all_steps_in_file(self, tmp_path) -> None: + """State file contains all 8 v2 steps.""" + 
state_file = tmp_path / "state.json" + state = PipelineState(db_url="postgresql://localhost/test", csv_dir="/tmp/csv") + state.save(state_file) + + data = json.loads(state_file.read_text()) + assert set(data["steps"].keys()) == set(STEPS) + + +class TestV1Migration: + """load() migrates v1 state files to v2.""" + + def _make_v1_state(self, tmp_path, completed_steps: list[str]) -> dict: + """Create a v1 state file and return its data.""" + v1_steps = { + name: {"status": "pending"} + for name in [ + "create_schema", + "import_csv", + "create_indexes", + "dedup", + "prune", + "vacuum", + ] + } + for step in completed_steps: + v1_steps[step] = {"status": "completed"} + + data = { + "version": 1, + "database_url": "postgresql://localhost/test", + "csv_dir": "/tmp/csv", + "steps": v1_steps, + } + state_file = tmp_path / "state.json" + state_file.write_text(json.dumps(data)) + return data + + def test_all_completed_v1(self, tmp_path) -> None: + """All v1 steps completed -> all v2 steps completed.""" + self._make_v1_state( + tmp_path, + ["create_schema", "import_csv", "create_indexes", "dedup", "prune", "vacuum"], + ) + state = PipelineState.load(tmp_path / "state.json") + + for step in STEPS: + assert state.is_completed(step), f"Step {step} should be completed after v1 migration" + + def test_import_csv_completed_implies_import_tracks(self, tmp_path) -> None: + """V1 import_csv completed -> import_tracks also completed.""" + self._make_v1_state(tmp_path, ["create_schema", "import_csv"]) + state = PipelineState.load(tmp_path / "state.json") + + assert state.is_completed("import_csv") + assert state.is_completed("import_tracks") + assert not state.is_completed("create_track_indexes") + + def test_create_indexes_completed_implies_create_track_indexes(self, tmp_path) -> None: + """V1 create_indexes completed -> create_track_indexes also completed.""" + self._make_v1_state(tmp_path, ["create_schema", "import_csv", "create_indexes"]) + state = PipelineState.load(tmp_path / "state.json") + + assert state.is_completed("create_indexes") + assert state.is_completed("create_track_indexes") + + def test_dedup_completed_implies_create_track_indexes(self, tmp_path) -> None: + """V1 dedup completed -> create_track_indexes also completed.""" + self._make_v1_state(tmp_path, ["create_schema", "import_csv", "create_indexes", "dedup"]) + state = PipelineState.load(tmp_path / "state.json") + + assert state.is_completed("dedup") + assert state.is_completed("import_tracks") + assert state.is_completed("create_track_indexes") + + def test_partial_v1_only_schema(self, tmp_path) -> None: + """V1 with only schema completed.""" + self._make_v1_state(tmp_path, ["create_schema"]) + state = PipelineState.load(tmp_path / "state.json") + + assert state.is_completed("create_schema") + assert not state.is_completed("import_csv") + assert not state.is_completed("import_tracks") + assert not state.is_completed("create_track_indexes") + + def test_v1_preserves_metadata(self, tmp_path) -> None: + """V1 migration preserves db_url and csv_dir.""" + self._make_v1_state(tmp_path, []) + state = PipelineState.load(tmp_path / "state.json") + + assert state.db_url == "postgresql://localhost/test" + assert state.csv_dir == "/tmp/csv" + class TestValidateResume: """validate_resume() rejects mismatched db_url or csv_dir.""" From d52400969d72a197f287862c23e6f52b1671cec0 Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Sat, 14 Feb 2026 15:45:41 -0800 Subject: [PATCH 2/3] fix: address PR review feedback - Add comment on f-string SQL noting 
values are trusted internal constants - Rename misleading test_track_tables_empty_before_track_import to test_deduped_release_has_no_tracks - Replace hardcoded step numbers with step name labels in run_pipeline.py comments to avoid renumbering when steps are added or reordered --- scripts/dedup_releases.py | 1 + scripts/run_pipeline.py | 34 ++++++++++++++++----------------- tests/integration/test_dedup.py | 6 ++---- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/scripts/dedup_releases.py b/scripts/dedup_releases.py index 874f7cd..8547b30 100644 --- a/scripts/dedup_releases.py +++ b/scripts/dedup_releases.py @@ -84,6 +84,7 @@ def ensure_dedup_ids(conn) -> int: ) with conn.cursor() as cur: + # track_count_join is built from trusted internal constants, not user input cur.execute(f""" CREATE UNLOGGED TABLE dedup_delete_ids AS SELECT id AS release_id FROM ( diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 9181ec3..4fe6902 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -400,13 +400,13 @@ def main() -> None: cleaned_csv_dir = tmp / "cleaned" filtered_csv_dir = tmp / "filtered" - # Step 1: Convert XML to CSV + # -- convert_xml: Convert XML to CSV convert_xml_to_csv(args.xml, args.xml2db, raw_csv_dir) - # Step 2: Fix CSV newlines + # -- fix_newlines: Fix CSV newlines fix_csv_newlines(raw_csv_dir, cleaned_csv_dir) - # Step 2.5: Enrich library_artists.txt (optional) + # -- enrich_artists: Enrich library_artists.txt (optional) if args.library_db: enriched_artists = tmp / "enriched_library_artists.txt" enrich_library_artists(args.library_db, enriched_artists, args.wxyc_db_url) @@ -414,13 +414,13 @@ def main() -> None: else: library_artists_path = args.library_artists - # Step 3: Filter to library artists + # -- filter_csv: Filter to library artists filter_to_library_artists(library_artists_path, cleaned_csv_dir, filtered_csv_dir) - # Steps 4-9: Database build + # -- database build (create_schema through vacuum) _run_database_build(db_url, filtered_csv_dir, args.library_db, python) else: - # Steps 4-9 only (--csv-dir mode) + # Database build only (--csv-dir mode) state = _load_or_create_state(args) _run_database_build( db_url, @@ -446,8 +446,7 @@ def _run_database_build( state: PipelineState | None = None, state_file: Path | None = None, ) -> None: - """Steps 4-11: schema, base import, base indexes, dedup, track import, - track indexes, prune/copy-to, vacuum. + """Database build: create_schema through vacuum. When *state* is provided, completed steps are skipped and progress is saved to *state_file* after each step. 
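The per-step pattern repeated below -- skip the step if the saved state marks it completed, otherwise run it, mark it completed, and save -- could in principle be factored into one small helper. The following is a hypothetical sketch only, not code from this patch; run_step is an invented name and the state argument is assumed to be a PipelineState (or None when resume is disabled):

from collections.abc import Callable
from pathlib import Path
import logging

logger = logging.getLogger(__name__)


def run_step(name: str, func: Callable[[], None], state, state_file: Path | None) -> None:
    # Skip steps the saved pipeline state already marks as completed.
    if state is not None and state.is_completed(name):
        logger.info("Skipping %s (already completed)", name)
        return
    func()
    # Record progress after each step so an interrupted run can resume.
    if state is not None:
        state.mark_completed(name)
        if state_file is not None:
            state.save(state_file)

The patch instead keeps the steps inline and labels the surrounding comments by step name rather than number, so adding or reordering steps does not force a renumbering pass.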
@@ -460,10 +459,9 @@ def _save_state() -> None: if state is not None and state_file is not None: state.save(state_file) - # Step 4: Wait for Postgres wait_for_postgres(db_url) - # Step 5: Create schema and functions + # -- create_schema if state and state.is_completed("create_schema"): logger.info("Skipping create_schema (already completed)") else: @@ -473,7 +471,7 @@ def _save_state() -> None: state.mark_completed("create_schema") _save_state() - # Step 6: Import base CSVs (release, release_artist, artwork, cache_metadata, track counts) + # -- import_csv (base tables, artwork, cache_metadata, track counts) if state and state.is_completed("import_csv"): logger.info("Skipping import_csv (already completed)") else: @@ -485,7 +483,7 @@ def _save_state() -> None: state.mark_completed("import_csv") _save_state() - # Step 7: Create base trigram indexes (strip CONCURRENTLY for fresh DB) + # -- create_indexes (base trigram indexes, strip CONCURRENTLY for fresh DB) if state and state.is_completed("create_indexes"): logger.info("Skipping create_indexes (already completed)") else: @@ -494,7 +492,7 @@ def _save_state() -> None: state.mark_completed("create_indexes") _save_state() - # Step 8: Deduplicate by master_id + # -- dedup (deduplicate by master_id) if state and state.is_completed("dedup"): logger.info("Skipping dedup (already completed)") else: @@ -506,7 +504,7 @@ def _save_state() -> None: state.mark_completed("dedup") _save_state() - # Step 9: Import tracks (filtered to surviving release IDs) + # -- import_tracks (filtered to surviving release IDs) if state and state.is_completed("import_tracks"): logger.info("Skipping import_tracks (already completed)") else: @@ -518,7 +516,7 @@ def _save_state() -> None: state.mark_completed("import_tracks") _save_state() - # Step 10: Create track indexes (FK constraints, FK indexes, trigram indexes) + # -- create_track_indexes (FK constraints, FK indexes, trigram indexes) if state and state.is_completed("create_track_indexes"): logger.info("Skipping create_track_indexes (already completed)") else: @@ -527,7 +525,7 @@ def _save_state() -> None: state.mark_completed("create_track_indexes") _save_state() - # Step 11: Prune or copy-to (optional) + # -- prune (or copy-to, optional) if state and state.is_completed("prune"): logger.info("Skipping prune/copy-to (already completed)") elif library_db and target_db_url: @@ -559,7 +557,7 @@ def _save_state() -> None: state.mark_completed("prune") _save_state() - # Step 12: Vacuum (on target DB if using copy-to, otherwise source) + # -- vacuum (on target DB if using copy-to, otherwise source) vacuum_db = target_db_url if target_db_url else db_url if state and state.is_completed("vacuum"): logger.info("Skipping vacuum (already completed)") @@ -569,7 +567,7 @@ def _save_state() -> None: state.mark_completed("vacuum") _save_state() - # Step 13: Report + # -- report report_sizes(vacuum_db) diff --git a/tests/integration/test_dedup.py b/tests/integration/test_dedup.py index fb39268..7ce51b7 100644 --- a/tests/integration/test_dedup.py +++ b/tests/integration/test_dedup.py @@ -278,12 +278,10 @@ def test_base_fk_constraints_recreated(self) -> None: expected = {"release_artist", "cache_metadata"} assert expected.issubset(fk_tables) - def test_track_tables_empty_before_track_import(self) -> None: - """Track tables are empty after base-only dedup (verified via count of - deduped releases -- 1001, 1003, 2001 should have no tracks).""" + def test_deduped_release_has_no_tracks(self) -> None: + """Releases removed by dedup have no 
tracks (not imported for them).""" conn = self._connect() with conn.cursor() as cur: - # Deduped release 1001 should have no tracks (it was deleted) cur.execute("SELECT count(*) FROM release_track WHERE release_id = 1001") count = cur.fetchone()[0] conn.close() From 0de6f119d3b00358ee501c3206873bf7dfd39c1b Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Sat, 14 Feb 2026 21:51:20 -0800 Subject: [PATCH 3/3] feat: add release_label table to ETL pipeline Add label data (release_id, label_name) as a base table throughout the pipeline: schema, CSV filter, import, dedup copy-swap, verify/copy-to, vacuum, and reporting. Labels follow the same pattern as release_artist -- FK CASCADE child table, imported before dedup, copy-swapped during dedup, streamed during copy-to. Also fix a pre-existing bug where copy-to targets were missing track trigram indexes (_create_target_indexes now runs create_track_indexes.sql in addition to create_indexes.sql). --- schema/create_database.sql | 7 +++++ scripts/dedup_releases.py | 9 +++++-- scripts/filter_csv.py | 3 ++- scripts/import_csv.py | 16 +++++++++--- scripts/run_pipeline.py | 5 ++-- scripts/verify_cache.py | 12 ++++++--- tests/e2e/test_pipeline.py | 24 ++++++++++++++--- tests/fixtures/create_fixtures.py | 39 ++++++++++++++++++++++++++++ tests/fixtures/csv/release_label.csv | 17 ++++++++++++ tests/integration/test_dedup.py | 33 ++++++++++++++++++++++- tests/integration/test_import.py | 22 ++++++++++++++++ tests/integration/test_schema.py | 6 ++++- tests/unit/test_import_csv.py | 22 +++++++++++++--- 13 files changed, 195 insertions(+), 20 deletions(-) create mode 100644 tests/fixtures/csv/release_label.csv diff --git a/schema/create_database.sql b/schema/create_database.sql index 9007ae7..9e74311 100644 --- a/schema/create_database.sql +++ b/schema/create_database.sql @@ -36,6 +36,12 @@ CREATE TABLE IF NOT EXISTS release_artist ( extra integer DEFAULT 0 -- 0 = main artist, 1 = extra credit ); +-- Labels on releases +CREATE TABLE IF NOT EXISTS release_label ( + release_id integer NOT NULL REFERENCES release(id) ON DELETE CASCADE, + label_name text NOT NULL +); + -- Tracks on releases CREATE TABLE IF NOT EXISTS release_track ( release_id integer NOT NULL REFERENCES release(id) ON DELETE CASCADE, @@ -69,6 +75,7 @@ CREATE TABLE IF NOT EXISTS cache_metadata ( -- Foreign key indexes CREATE INDEX IF NOT EXISTS idx_release_artist_release_id ON release_artist(release_id); +CREATE INDEX IF NOT EXISTS idx_release_label_release_id ON release_label(release_id); CREATE INDEX IF NOT EXISTS idx_release_track_release_id ON release_track(release_id); CREATE INDEX IF NOT EXISTS idx_release_track_artist_release_id ON release_track_artist(release_id); diff --git a/scripts/dedup_releases.py b/scripts/dedup_releases.py index 8547b30..16decab 100644 --- a/scripts/dedup_releases.py +++ b/scripts/dedup_releases.py @@ -165,11 +165,14 @@ def add_base_constraints_and_indexes(conn) -> None: # FK constraints with CASCADE (base tables only) "ALTER TABLE release_artist ADD CONSTRAINT fk_release_artist_release " "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", + "ALTER TABLE release_label ADD CONSTRAINT fk_release_label_release " + "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", "ALTER TABLE cache_metadata ADD CONSTRAINT fk_cache_metadata_release " "FOREIGN KEY (release_id) REFERENCES release(id) ON DELETE CASCADE", "ALTER TABLE cache_metadata ADD PRIMARY KEY (release_id)", # FK indexes (base tables only) "CREATE INDEX idx_release_artist_release_id 
ON release_artist(release_id)", + "CREATE INDEX idx_release_label_release_id ON release_label(release_id)", # Base trigram indexes for fuzzy search (accent-insensitive via f_unaccent) "CREATE INDEX idx_release_artist_name_trgm ON release_artist " "USING gin (lower(f_unaccent(artist_name)) gin_trgm_ops)", @@ -264,6 +267,7 @@ def main(): tables = [ ("release", "new_release", "id, title, release_year, artwork_url", "id"), ("release_artist", "new_release_artist", "release_id, artist_name, extra", "release_id"), + ("release_label", "new_release_label", "release_id, label_name", "release_id"), ( "cache_metadata", "new_cache_metadata", @@ -280,6 +284,7 @@ def main(): with conn.cursor() as cur: for stmt in [ "ALTER TABLE release_artist DROP CONSTRAINT IF EXISTS fk_release_artist_release", + "ALTER TABLE release_label DROP CONSTRAINT IF EXISTS fk_release_label_release", "ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS fk_cache_metadata_release", ]: cur.execute(stmt) @@ -311,8 +316,8 @@ def main(): cur.execute(""" SELECT relname, pg_size_pretty(pg_total_relation_size(relid)) as total_size FROM pg_stat_user_tables - WHERE relname IN ('release', 'release_artist', 'release_track', - 'release_track_artist', 'cache_metadata') + WHERE relname IN ('release', 'release_artist', 'release_label', + 'release_track', 'release_track_artist', 'cache_metadata') ORDER BY pg_total_relation_size(relid) DESC """) logger.info("Table sizes:") diff --git a/scripts/filter_csv.py b/scripts/filter_csv.py index 40089d5..696155c 100644 --- a/scripts/filter_csv.py +++ b/scripts/filter_csv.py @@ -24,10 +24,11 @@ # CSV files that need to be filtered by release_id. # Only includes files needed by the optimized schema (see 04-create-database.sql). -# Dropped tables (release_label, release_genre, release_style) are excluded. +# Dropped tables (release_genre, release_style) are excluded. RELEASE_ID_FILES = [ "release.csv", "release_artist.csv", + "release_label.csv", "release_track.csv", "release_track_artist.csv", "release_image.csv", # for artwork_url extraction during import diff --git a/scripts/import_csv.py b/scripts/import_csv.py index 66cfc77..e393ca7 100644 --- a/scripts/import_csv.py +++ b/scripts/import_csv.py @@ -2,7 +2,7 @@ """Import Discogs CSV files into PostgreSQL with proper multiline handling. Imports only the columns needed by the optimized schema (see 04-create-database.sql). -Dropped tables (release_label, release_genre, release_style, artist) are skipped. +Dropped tables (release_genre, release_style, artist) are skipped. The release_image.csv is processed separately to populate artwork_url on release. 
""" @@ -60,13 +60,14 @@ def count_tracks_from_csv(csv_path: Path) -> dict[int, int]: return counts -class TableConfig(TypedDict): +class TableConfig(TypedDict, total=False): csv_file: str table: str csv_columns: list[str] db_columns: list[str] required: list[str] transforms: dict[str, Callable[[str | None], str | None]] + unique_key: list[str] BASE_TABLES: list[TableConfig] = [ @@ -87,6 +88,15 @@ class TableConfig(TypedDict): "transforms": {}, "unique_key": ["release_id", "artist_name"], }, + { + "csv_file": "release_label.csv", + "table": "release_label", + "csv_columns": ["release_id", "label"], + "db_columns": ["release_id", "label_name"], + "required": ["release_id", "label"], + "transforms": {}, + "unique_key": ["release_id", "label"], + }, ] TRACK_TABLES: list[TableConfig] = [ @@ -386,7 +396,7 @@ def main(): mode.add_argument( "--base-only", action="store_true", - help="Import only base tables (release, release_artist) " + help="Import only base tables (release, release_artist, release_label) " "plus artwork, cache_metadata, and track counts", ) mode.add_argument( diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 4fe6902..18cc478 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -233,6 +233,7 @@ def run_vacuum(db_url: str) -> None: tables = [ "release", "release_artist", + "release_label", "release_track", "release_track_artist", "cache_metadata", @@ -260,8 +261,8 @@ def report_sizes(db_url: str) -> None: pg_size_pretty(pg_total_relation_size(relid)) as total_size FROM pg_stat_user_tables WHERE relname IN ( - 'release', 'release_artist', 'release_track', - 'release_track_artist', 'cache_metadata' + 'release', 'release_artist', 'release_label', + 'release_track', 'release_track_artist', 'cache_metadata' ) ORDER BY pg_total_relation_size(relid) DESC """) diff --git a/scripts/verify_cache.py b/scripts/verify_cache.py index c8c8d5c..c60cf16 100644 --- a/scripts/verify_cache.py +++ b/scripts/verify_cache.py @@ -87,6 +87,7 @@ RELEASE_TABLES = [ ("release", "id"), ("release_artist", "release_id"), + ("release_label", "release_id"), ("release_track", "release_id"), ("release_track_artist", "release_id"), ("cache_metadata", "release_id"), @@ -661,6 +662,7 @@ async def prune_releases(conn: asyncpg.Connection, release_ids: set[int]) -> dic COPY_TABLE_SPEC = [ ("release", "id", ["id", "title", "release_year", "artwork_url"]), ("release_artist", "release_id", ["release_id", "artist_name", "extra"]), + ("release_label", "release_id", ["release_id", "label_name"]), ("release_track", "release_id", ["release_id", "sequence", "position", "title", "duration"]), ( "release_track_artist", @@ -717,6 +719,7 @@ def _create_target_schema(target_url: str) -> None: "cache_metadata", "release_track_artist", "release_track", + "release_label", "release_artist", "release", ): @@ -730,13 +733,16 @@ def _create_target_schema(target_url: str) -> None: def _create_target_indexes(target_url: str) -> None: """Create functions and indexes on the target database (without CONCURRENTLY).""" functions_sql = SCHEMA_DIR.joinpath("create_functions.sql").read_text() - indexes_sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() - indexes_sql = indexes_sql.replace(" CONCURRENTLY", "") + base_indexes_sql = SCHEMA_DIR.joinpath("create_indexes.sql").read_text() + base_indexes_sql = base_indexes_sql.replace(" CONCURRENTLY", "") + track_indexes_sql = SCHEMA_DIR.joinpath("create_track_indexes.sql").read_text() + track_indexes_sql = track_indexes_sql.replace(" CONCURRENTLY", "") conn = 
psycopg.connect(target_url, autocommit=True) with conn.cursor() as cur: cur.execute(functions_sql) - cur.execute(indexes_sql) + cur.execute(base_indexes_sql) + cur.execute(track_indexes_sql) conn.close() logger.info("Created functions and indexes on target database") diff --git a/tests/e2e/test_pipeline.py b/tests/e2e/test_pipeline.py index f49defa..1920ee2 100644 --- a/tests/e2e/test_pipeline.py +++ b/tests/e2e/test_pipeline.py @@ -110,7 +110,13 @@ def test_tables_populated(self) -> None: compilation releases, which may be pruned depending on matching. """ conn = self._connect() - for table in ("release", "release_artist", "release_track", "cache_metadata"): + for table in ( + "release", + "release_artist", + "release_label", + "release_track", + "cache_metadata", + ): with conn.cursor() as cur: cur.execute(f"SELECT count(*) FROM {table}") count = cur.fetchone()[0] @@ -195,7 +201,13 @@ def test_fk_constraints_exist(self) -> None: """) fk_tables = {row[0] for row in cur.fetchall()} conn.close() - expected = {"release_artist", "release_track", "release_track_artist", "cache_metadata"} + expected = { + "release_artist", + "release_label", + "release_track", + "release_track_artist", + "cache_metadata", + } assert expected.issubset(fk_tables) def test_null_title_release_not_imported(self) -> None: @@ -379,7 +391,13 @@ def test_target_has_indexes(self) -> None: def test_target_tables_populated(self) -> None: """Core tables in target have rows.""" conn = psycopg.connect(self.target_url) - for table in ("release", "release_artist", "release_track", "cache_metadata"): + for table in ( + "release", + "release_artist", + "release_label", + "release_track", + "cache_metadata", + ): with conn.cursor() as cur: cur.execute(f"SELECT count(*) FROM {table}") count = cur.fetchone()[0] diff --git a/tests/fixtures/create_fixtures.py b/tests/fixtures/create_fixtures.py index b5c613f..0e8a693 100644 --- a/tests/fixtures/create_fixtures.py +++ b/tests/fixtures/create_fixtures.py @@ -184,6 +184,44 @@ def create_release_track_artist_csv() -> None: write_csv("release_track_artist.csv", headers, rows) +def create_release_label_csv() -> None: + """Create release_label.csv with label names for releases. 
+ + Includes: + - Multiple labels per release (release 1001 has Parlophone and Capitol Records) + - Labels for releases in the same dedup group (1001, 1002, 1003) + - Labels for releases that won't match the library (5001, 5002) + """ + headers = ["release_id", "label", "catno"] + rows = [ + # Radiohead - OK Computer (dedup group, master_id 500) + [1001, "Parlophone", "7243 8 55229 2 8"], + [1001, "Capitol Records", "CDP 7243 8 55229 2 8"], + [1002, "Capitol Records", "C1-55229"], + [1003, "EMI", "TOCP-50201"], + # Joy Division - Unknown Pleasures (dedup group, master_id 600) + [2001, "Factory Records", "FACT 10"], + [2002, "Qwest Records", "1-25840"], + # Unique releases + [3001, "Parlophone", "7243 5 27753 2 3"], + [4001, "Parlophone", "7243 5 32764 2 8"], + # Won't match library + [5001, "Unknown Label", "UNK-001"], + [5002, "Mystery Records", "MYS-002"], + # Bjork + [6001, "One Little Indian", "TPLP 71 CD"], + # Compilation + [8001, "Sugar Hill Records", "SH-542"], + # Beatles, Simon & Garfunkel + [9001, "Apple Records", "PCS 7088"], + [9002, "Columbia", "KCS 9914"], + # Not in library + [10001, "Random Label", "RL-001"], + [10002, "Obscure Label", "OL-002"], + ] + write_csv("release_label.csv", headers, rows) + + def create_release_image_csv() -> None: """Create release_image.csv for artwork URL testing.""" headers = ["release_id", "type", "width", "height", "uri"] @@ -290,6 +328,7 @@ def main() -> None: create_release_artist_csv() create_release_track_csv() create_release_track_artist_csv() + create_release_label_csv() create_release_image_csv() print() print("Library data:") diff --git a/tests/fixtures/csv/release_label.csv b/tests/fixtures/csv/release_label.csv new file mode 100644 index 0000000..aaffd2d --- /dev/null +++ b/tests/fixtures/csv/release_label.csv @@ -0,0 +1,17 @@ +release_id,label,catno +1001,Parlophone,7243 8 55229 2 8 +1001,Capitol Records,CDP 7243 8 55229 2 8 +1002,Capitol Records,C1-55229 +1003,EMI,TOCP-50201 +2001,Factory Records,FACT 10 +2002,Qwest Records,1-25840 +3001,Parlophone,7243 5 27753 2 3 +4001,Parlophone,7243 5 32764 2 8 +5001,Unknown Label,UNK-001 +5002,Mystery Records,MYS-002 +6001,One Little Indian,TPLP 71 CD +8001,Sugar Hill Records,SH-542 +9001,Apple Records,PCS 7088 +9002,Columbia,KCS 9914 +10001,Random Label,RL-001 +10002,Obscure Label,OL-002 diff --git a/tests/integration/test_dedup.py b/tests/integration/test_dedup.py index 7ce51b7..bb368f5 100644 --- a/tests/integration/test_dedup.py +++ b/tests/integration/test_dedup.py @@ -42,6 +42,7 @@ "cache_metadata", "release_track_artist", "release_track", + "release_label", "release_artist", "release", ) @@ -112,6 +113,12 @@ def _run_dedup(db_url: str) -> None: "release_id, artist_name, extra", "release_id", ), + ( + "release_label", + "new_release_label", + "release_id, label_name", + "release_id", + ), ( "cache_metadata", "new_cache_metadata", @@ -127,6 +134,7 @@ def _run_dedup(db_url: str) -> None: with conn.cursor() as cur: for stmt in [ "ALTER TABLE release_artist DROP CONSTRAINT IF EXISTS fk_release_artist_release", + "ALTER TABLE release_label DROP CONSTRAINT IF EXISTS fk_release_label_release", "ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS fk_cache_metadata_release", ]: cur.execute(stmt) @@ -225,12 +233,35 @@ def test_child_table_rows_cleaned(self) -> None: with conn.cursor() as cur: cur.execute("SELECT count(*) FROM release_artist WHERE release_id = 1001") artist_count = cur.fetchone()[0] + cur.execute("SELECT count(*) FROM release_label WHERE release_id = 1001") + label_count = 
cur.fetchone()[0] cur.execute("SELECT count(*) FROM release_track WHERE release_id = 1001") track_count = cur.fetchone()[0] conn.close() assert artist_count == 0 + assert label_count == 0 assert track_count == 0 + def test_kept_release_labels_preserved(self) -> None: + """The kept release still has its labels after dedup.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute( + "SELECT label_name FROM release_label WHERE release_id = 1002 ORDER BY label_name" + ) + labels = [row[0] for row in cur.fetchall()] + conn.close() + assert labels == ["Capitol Records"] + + def test_deduped_release_has_no_labels(self) -> None: + """Releases removed by dedup have no labels.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_label WHERE release_id = 1001") + count = cur.fetchone()[0] + conn.close() + assert count == 0 + def test_kept_release_tracks_preserved(self) -> None: """The kept release still has its tracks (imported after dedup).""" conn = self._connect() @@ -275,7 +306,7 @@ def test_base_fk_constraints_recreated(self) -> None: """) fk_tables = {row[0] for row in cur.fetchall()} conn.close() - expected = {"release_artist", "cache_metadata"} + expected = {"release_artist", "release_label", "cache_metadata"} assert expected.issubset(fk_tables) def test_deduped_release_has_no_tracks(self) -> None: diff --git a/tests/integration/test_import.py b/tests/integration/test_import.py index fabdc53..589ee01 100644 --- a/tests/integration/test_import.py +++ b/tests/integration/test_import.py @@ -90,6 +90,27 @@ def test_release_artist_row_count(self) -> None: # 16 rows in fixture CSV (all have required fields) assert count == 16 + def test_release_label_row_count(self) -> None: + """All label rows imported (one per unique release_id+label pair).""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release_label") + count = cur.fetchone()[0] + conn.close() + # 16 rows in fixture CSV, all unique (release_id, label) pairs + assert count == 16 + + def test_release_label_column_mapping(self) -> None: + """CSV 'label' column maps to DB 'label_name'.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute( + "SELECT label_name FROM release_label WHERE release_id = 1001 ORDER BY label_name" + ) + labels = [row[0] for row in cur.fetchall()] + conn.close() + assert labels == ["Capitol Records", "Parlophone"] + def test_release_track_row_count(self) -> None: conn = self._connect() with conn.cursor() as cur: @@ -210,6 +231,7 @@ def test_cache_metadata_source(self) -> None: "cache_metadata", "release_track_artist", "release_track", + "release_label", "release_artist", "release", ) diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index b874512..87897eb 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -40,6 +40,7 @@ def test_all_tables_exist(self) -> None: expected = { "release", "release_artist", + "release_label", "release_track", "release_track_artist", "cache_metadata", @@ -59,6 +60,7 @@ def test_all_tables_exist(self) -> None: [ ("release", {"id", "title", "release_year", "artwork_url", "master_id"}), ("release_artist", {"release_id", "artist_name", "extra"}), + ("release_label", {"release_id", "label_name"}), ("release_track", {"release_id", "sequence", "position", "title", "duration"}), ("release_track_artist", {"release_id", "track_sequence", "artist_name"}), ("cache_metadata", {"release_id", "cached_at", "source", 
"last_validated"}), @@ -66,6 +68,7 @@ def test_all_tables_exist(self) -> None: ids=[ "release", "release_artist", + "release_label", "release_track", "release_track_artist", "cache_metadata", @@ -127,6 +130,7 @@ def test_fk_constraints_with_cascade(self) -> None: conn.close() expected_fk_tables = { "release_artist", + "release_label", "release_track", "release_track_artist", "cache_metadata", @@ -145,7 +149,7 @@ def test_no_unique_constraints_on_child_tables(self) -> None: SELECT tc.table_name, tc.constraint_name FROM information_schema.table_constraints tc WHERE tc.constraint_type = 'UNIQUE' - AND tc.table_name IN ('release_artist', 'release_track_artist') + AND tc.table_name IN ('release_artist', 'release_label', 'release_track_artist') """) unique_constraints = cur.fetchall() conn.close() diff --git a/tests/unit/test_import_csv.py b/tests/unit/test_import_csv.py index 1701191..27657dd 100644 --- a/tests/unit/test_import_csv.py +++ b/tests/unit/test_import_csv.py @@ -111,7 +111,7 @@ def test_release_table_transforms_released_to_year(self) -> None: @pytest.mark.parametrize( "table_name", - ["release", "release_artist", "release_track", "release_track_artist"], + ["release", "release_artist", "release_label", "release_track", "release_track_artist"], ) def test_table_has_csv_file(self, table_name: str) -> None: """Each table config specifies a CSV file.""" @@ -129,7 +129,7 @@ def test_all_tables_have_required_keys(self) -> None: def test_tables_with_unique_constraints_have_unique_key(self) -> None: """Tables with unique constraints must specify unique_key for dedup during import.""" - tables_needing_dedup = {"release_artist", "release_track_artist"} + tables_needing_dedup = {"release_artist", "release_label", "release_track_artist"} for table_config in TABLES: if table_config["table"] in tables_needing_dedup: assert "unique_key" in table_config, ( @@ -178,6 +178,20 @@ def test_release_artist_csv_has_expected_columns(self) -> None: f"Expected column {col!r} not in release_artist.csv headers: {headers}" ) + def test_release_label_csv_has_expected_columns(self) -> None: + import csv as csv_mod + + csv_path = Path(__file__).parent.parent / "fixtures" / "csv" / "release_label.csv" + with open(csv_path) as f: + reader = csv_mod.DictReader(f) + headers = reader.fieldnames + assert headers is not None + rl_config = next(t for t in TABLES if t["table"] == "release_label") + for col in rl_config["csv_columns"]: + assert col in headers, ( + f"Expected column {col!r} not in release_label.csv headers: {headers}" + ) + # --------------------------------------------------------------------------- # count_tracks_from_csv @@ -230,9 +244,9 @@ class TestTableSplit: def test_tables_is_union(self) -> None: assert TABLES == BASE_TABLES + TRACK_TABLES - def test_base_tables_are_release_and_release_artist(self) -> None: + def test_base_tables_names(self) -> None: names = [t["table"] for t in BASE_TABLES] - assert names == ["release", "release_artist"] + assert names == ["release", "release_artist", "release_label"] def test_track_tables_are_release_track_and_release_track_artist(self) -> None: names = [t["table"] for t in TRACK_TABLES]