202 changes: 202 additions & 0 deletions lib/conflicts/scripts/postprocess.py
@@ -0,0 +1,202 @@
import argparse
import hashlib
import json
import random
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple


def get_git_sha() -> str:
"""Get current git SHA-1 hash."""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
)
return result.stdout.strip()[:7] # Short SHA
    except (subprocess.CalledProcessError, FileNotFoundError):  # git failed or not installed
        return "nogit"


def get_pair_id(item: Dict[str, Any]) -> str:
"""Extract unique pair ID from data item."""
data = item.get("data", {})
doc1 = data.get("doc_1", "")
doc2 = data.get("doc_2", "")

# Create unique identifier from document content only
content = f"{doc1}{doc2}"
return hashlib.md5(content.encode()).hexdigest()[:8]


def calculate_average_score(item: Dict[str, Any]) -> float:
"""Calculate average of the three individual scores."""
clinical_plausibility = item.get("clinical_plausibility_score", 0)
temporal_appropriateness = item.get("temporal_appropriateness_score", 0)
clinical_significance = item.get("clinical_significance_score", 0)

scores = [clinical_plausibility, temporal_appropriateness, clinical_significance]
valid_scores = [s for s in scores if s is not None and s != 0]

return sum(valid_scores) / len(valid_scores) if valid_scores else 0.0


def load_and_process_files(file_paths: List[str]) -> List[Dict[str, Any]]:
"""Load and aggregate data from multiple JSON files."""
all_data = []

for file_path in file_paths:
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
all_data.extend(data if isinstance(data, list) else [data])
print(f"Loaded {file_path}")
except Exception as e:
print(f"Error loading {file_path}: {e}")

return all_data


def create_splits(
data: List[Dict[str, Any]], best_limit: int, low_score_limit: int
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Create two splits based on criteria."""
used_pair_ids = set()
best_split = []
low_score_split = []

# Split 1: Best propositions
for item in data:
if len(best_split) >= best_limit:
break
pair_id = get_pair_id(item)
is_best = item.get("data", {}).get("best_conflict", False)

if is_best and pair_id not in used_pair_ids:
best_split.append(item)
used_pair_ids.add(pair_id)

# Split 2: Low-scored propositions
for item in data:
if len(low_score_split) >= low_score_limit:
break
pair_id = get_pair_id(item)
avg_score = calculate_average_score(item)

if avg_score < 4.0 and pair_id not in used_pair_ids:
Contributor comment:
This value is used both in the moderator agents and in the post-processing steps; it may be worth saving it in a constants.py/utils.py for consistency later on. (A sketch of this suggestion follows below.)

low_score_split.append(item)
used_pair_ids.add(pair_id)
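Following up on the comment above, a minimal sketch of that suggestion, assuming a shared constants module (the lib/conflicts/constants.py path and the LOW_SCORE_THRESHOLD name are hypothetical, not part of this PR):

```python
# lib/conflicts/constants.py  (hypothetical module path and constant name)
# Single source of truth for the score cutoff shared by the moderator agents
# and the post-processing script.
LOW_SCORE_THRESHOLD: float = 4.0
```

postprocess.py would then import it (`from lib.conflicts.constants import LOW_SCORE_THRESHOLD`) and compare `avg_score < LOW_SCORE_THRESHOLD` instead of hard-coding 4.0.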

Contributor comment on lines +69 to +90:
The function iterates through the entire dataset twice. I think this could be done in a single pass. (A single-pass sketch follows this function.)

print(f"Best split: {len(best_split)} items")
print(f"Low score split: {len(low_score_split)} items")

return best_split, low_score_split
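Picking up the single-pass suggestion above, a sketch that keeps the same limits and pair-ID deduplication; it reuses get_pair_id and calculate_average_score from this file, and the selection can differ slightly from the two-pass version when a pair ID would qualify for both splits:

```python
def create_splits_single_pass(
    data: List[Dict[str, Any]], best_limit: int, low_score_limit: int
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Create both splits in a single iteration over the data."""
    used_pair_ids = set()
    best_split: List[Dict[str, Any]] = []
    low_score_split: List[Dict[str, Any]] = []

    for item in data:
        if len(best_split) >= best_limit and len(low_score_split) >= low_score_limit:
            break  # both splits are full
        pair_id = get_pair_id(item)
        if pair_id in used_pair_ids:
            continue  # keep pair IDs unique across both splits
        if item.get("data", {}).get("best_conflict", False) and len(best_split) < best_limit:
            best_split.append(item)
            used_pair_ids.add(pair_id)
        elif calculate_average_score(item) < 4.0 and len(low_score_split) < low_score_limit:
            low_score_split.append(item)
            used_pair_ids.add(pair_id)

    return best_split, low_score_split
```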


def create_output_filename() -> str:
"""Create output filename with git SHA and date."""
git_sha = get_git_sha()
Contributor comment:
Both create_output_filename and create_metadata call get_git_sha() at the start; it would be better to call it once in main(), before invoking these two functions, to avoid redundant calls. (A sketch follows this function.)

date_str = datetime.now().strftime("%d%m%Y")
return f"{git_sha}_{date_str}"


def create_metadata(
input_files: List[str],
best_count: int,
low_score_count: int,
total_count: int,
output_filename: str,
) -> str:
"""Create metadata content."""
git_sha = get_git_sha()
timestamp = datetime.now().isoformat()

metadata = f"""# Dataset Postprocessing Metadata

## Postprocessed Information
- **Postprocessed Date**: {timestamp}
- **Git SHA**: {git_sha}
- **Output File**: {output_filename}.json

## Input Files
"""

for file_path in input_files:
metadata += f"- `{Path(file_path).name}`\n"

metadata += f"""
## Postprocessed Results
- **Total Input Items**: {total_count}
- **Best Propositions Split**: {best_count} items (unique pair IDs)
- **Low Score Split**: {low_score_count} items (average score < 4.0, unique pair IDs)
- **Final Dataset Size**: {best_count + low_score_count} items
"""

return metadata


def main():
parser = argparse.ArgumentParser(description="Post-process conflict detection dataset files")
parser.add_argument("files", nargs="+", help="JSON files to process")
parser.add_argument(
"--best-limit",
type=int,
default=50,
help="Maximum number of best conflict items (default: 50)",
)
parser.add_argument(
"--low-score-limit",
type=int,
default=50,
help="Maximum number of low score items (default: 50)",
)
args = parser.parse_args()

# Load and process data
print(f"Processing {len(args.files)} files...")
all_data = load_and_process_files(args.files)

print(f"Total items loaded: {len(all_data)}")

# Create splits
best_split, low_score_split = create_splits(all_data, args.best_limit, args.low_score_limit)
merged_data = best_split + low_score_split

# Shuffle the merged data
random.shuffle(merged_data)

# Create output directory
output_dir = Path("postprocessed")
output_dir.mkdir(exist_ok=True)

# Create output filename
output_filename = create_output_filename()

# Write JSON output
json_path = output_dir / f"{output_filename}.json"
with open(json_path, "w", encoding="utf-8") as f:
json.dump(merged_data, f, indent=2, ensure_ascii=False)

# Write metadata
metadata_content = create_metadata(
args.files, len(best_split), len(low_score_split), len(all_data), output_filename
)

md_path = output_dir / f"{output_filename}.md"
with open(md_path, "w", encoding="utf-8") as f:
f.write(metadata_content)

print("Output files created:")
print(f" - {json_path}")
print(f" - {md_path}")
print(
f"\nFinal dataset: {len(merged_data)} items ({len(best_split)} best "
f"{len(low_score_split)} low score)"
)

return 0


if __name__ == "__main__":
exit(main())