-
Notifications
You must be signed in to change notification settings - Fork 1
add postprocessing script to select only 50% of best examples and 50% of random examples #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: ht+sm/conflict-maybe-lrec
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| import argparse | ||
| import hashlib | ||
| import json | ||
| import random | ||
| import subprocess | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| from typing import Any, Dict, List, Tuple | ||
|
|
||
|
|
||
def get_git_sha() -> str:
    """Return the short (7-character) SHA-1 of the current git HEAD.

    Returns:
        The first 7 characters of the HEAD commit hash, or ``"nogit"``
        when git is unavailable or the working directory is not a repo.
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
        )
        return result.stdout.strip()[:7]  # short SHA
    # OSError covers a missing git executable (FileNotFoundError);
    # CalledProcessError covers "not a git repository" and similar failures.
    except (subprocess.CalledProcessError, OSError):
        return "nogit"
|
|
||
|
|
||
def get_pair_id(item: Dict[str, Any]) -> str:
    """Derive a stable 8-character identifier for the document pair in *item*.

    Items carrying identical ``doc_1``/``doc_2`` texts map to the same ID,
    regardless of any other fields they contain.
    """
    payload = item.get("data", {})
    combined = payload.get("doc_1", "") + payload.get("doc_2", "")
    # Hash only the concatenated document texts so the ID depends on
    # content alone.
    return hashlib.md5(combined.encode()).hexdigest()[:8]
|
|
||
|
|
||
def calculate_average_score(item: Dict[str, Any]) -> float:
    """Average the three per-criterion scores of *item*.

    ``None`` and ``0`` values are treated as "not scored" and excluded
    from the average. Returns 0.0 when no usable score is present.
    """
    raw = (
        item.get("clinical_plausibility_score", 0),
        item.get("temporal_appropriateness_score", 0),
        item.get("clinical_significance_score", 0),
    )
    # Zero counts as missing, exactly like None.
    usable = [value for value in raw if value is not None and value != 0]
    if not usable:
        return 0.0
    return sum(usable) / len(usable)
|
|
||
|
|
||
def load_and_process_files(file_paths: List[str]) -> List[Dict[str, Any]]:
    """Read every JSON file in *file_paths* and merge their contents.

    Files holding a JSON array contribute all of their elements; files
    holding a single object contribute that one object. Unreadable or
    invalid files are reported and skipped (best effort).
    """
    aggregated: List[Dict[str, Any]] = []
    for path in file_paths:
        try:
            with open(path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except Exception as exc:  # best-effort: report bad files, keep going
            print(f"Error loading {path}: {exc}")
            continue
        if isinstance(payload, list):
            aggregated.extend(payload)
        else:
            aggregated.append(payload)
        print(f"Loaded {path}")
    return aggregated
|
|
||
|
|
||
def create_splits(
    data: List[Dict[str, Any]], best_limit: int, low_score_limit: int
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split *data* into "best conflict" and "low average score" subsets.

    Selection rules (order-sensitive; a pair ID is never used twice):
      * Split 1: up to *best_limit* items flagged ``data.best_conflict``,
        in input order. This split has priority on pair IDs.
      * Split 2: up to *low_score_limit* items with average score < 4.0
        whose pair ID was not consumed by split 1, in input order.

    Args:
        data: Aggregated evaluation items.
        best_limit: Maximum size of the best-conflict split.
        low_score_limit: Maximum size of the low-score split.

    Returns:
        Tuple of ``(best_split, low_score_split)``.
    """
    used_pair_ids = set()
    best_split: List[Dict[str, Any]] = []
    low_candidates: List[Tuple[str, Dict[str, Any]]] = []

    # Single pass over the data (previously two full passes): fill the
    # best split immediately — it has priority on pair IDs — and remember
    # low-score candidates for a cheap filtering step afterwards. This
    # also computes each item's pair ID only once instead of twice.
    for item in data:
        pair_id = get_pair_id(item)
        is_best = item.get("data", {}).get("best_conflict", False)

        if is_best and len(best_split) < best_limit and pair_id not in used_pair_ids:
            best_split.append(item)
            used_pair_ids.add(pair_id)
        elif calculate_average_score(item) < 4.0:
            low_candidates.append((pair_id, item))

    low_score_split: List[Dict[str, Any]] = []
    for pair_id, item in low_candidates:
        if len(low_score_split) >= low_score_limit:
            break
        # Skip pairs already consumed by the best split or an earlier
        # low-score candidate.
        if pair_id not in used_pair_ids:
            low_score_split.append(item)
            used_pair_ids.add(pair_id)

    print(f"Best split: {len(best_split)} items")
    print(f"Low score split: {len(low_score_split)} items")

    return best_split, low_score_split
|
|
||
|
|
||
def create_output_filename() -> str:
    """Build an output basename of the form ``<shortsha>_<DDMMYYYY>``."""
    # NOTE(review): this naming scheme is shared with other pipeline steps;
    # consider centralizing it in a constants module for consistency.
    stamp = datetime.now().strftime("%d%m%Y")
    return f"{get_git_sha()}_{stamp}"
|
|
||
|
|
||
def create_metadata(
    input_files: List[str],
    best_count: int,
    low_score_count: int,
    total_count: int,
    output_filename: str,
) -> str:
    """Render the markdown metadata report for a postprocessing run.

    Args:
        input_files: Paths of the JSON files that were aggregated.
        best_count: Number of items in the best-conflict split.
        low_score_count: Number of items in the low-score split.
        total_count: Total number of input items loaded.
        output_filename: Basename (no extension) of the output files.

    Returns:
        The full markdown document as a string.
    """
    header = f"""# Dataset Postprocessing Metadata

## Postprocessed Information
- **Postprocessed Date**: {datetime.now().isoformat()}
- **Git SHA**: {get_git_sha()}
- **Output File**: {output_filename}.json

## Input Files
"""
    # One bullet per input file, listed by basename only.
    file_section = "".join(f"- `{Path(path).name}`\n" for path in input_files)
    summary = f"""
## Postprocessed Results
- **Total Input Items**: {total_count}
- **Best Propositions Split**: {best_count} items (unique pair IDs)
- **Low Score Split**: {low_score_count} items (average score < 4.0, unique pair IDs)
- **Final Dataset Size**: {best_count + low_score_count} items
"""
    return header + file_section + summary
|
|
||
|
|
||
def main():
    """CLI entry point: load input files, build splits, write dataset + metadata.

    Returns:
        Process exit code (0 on success).
    """
    parser = argparse.ArgumentParser(description="Post-process conflict detection dataset files")
    parser.add_argument("files", nargs="+", help="JSON files to process")
    parser.add_argument(
        "--best-limit",
        type=int,
        default=50,
        help="Maximum number of best conflict items (default: 50)",
    )
    parser.add_argument(
        "--low-score-limit",
        type=int,
        default=50,
        help="Maximum number of low score items (default: 50)",
    )
    args = parser.parse_args()

    # Load and aggregate input data.
    print(f"Processing {len(args.files)} files...")
    all_data = load_and_process_files(args.files)
    print(f"Total items loaded: {len(all_data)}")

    # Build the two splits and merge them into the final dataset.
    best_split, low_score_split = create_splits(all_data, args.best_limit, args.low_score_limit)
    merged_data = best_split + low_score_split

    # Shuffle so the two splits are interleaved in the output file.
    random.shuffle(merged_data)

    # Ensure the output directory exists.
    output_dir = Path("postprocessed")
    output_dir.mkdir(exist_ok=True)

    output_filename = create_output_filename()

    # Write the merged dataset.
    json_path = output_dir / f"{output_filename}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(merged_data, f, indent=2, ensure_ascii=False)

    # Write the run metadata alongside it.
    metadata_content = create_metadata(
        args.files, len(best_split), len(low_score_split), len(all_data), output_filename
    )
    md_path = output_dir / f"{output_filename}.md"
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(metadata_content)

    print("Output files created:")
    print(f"  - {json_path}")
    print(f"  - {md_path}")
    # Bug fix: the two counts previously ran together with no separator
    # ("... (50 best 50 low score)").
    print(
        f"\nFinal dataset: {len(merged_data)} items ({len(best_split)} best, "
        f"{len(low_score_split)} low score)"
    )

    return 0
|
|
||
|
|
||
if __name__ == "__main__":
    # Use SystemExit directly: exit() is the interactive helper installed by
    # the site module and is not guaranteed to exist in all run modes;
    # raising SystemExit (or sys.exit) is the correct form for scripts.
    raise SystemExit(main())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This value will be used in both the moderator agents and the post-processing steps; it may be worth saving it in a
`constants.py`/`utils.py` for later consistency.