diff --git a/lib/conflicts/conflicts/agents/doctor_agent.py b/lib/conflicts/conflicts/agents/doctor_agent.py index c7c5ca1..d1aff8f 100644 --- a/lib/conflicts/conflicts/agents/doctor_agent.py +++ b/lib/conflicts/conflicts/agents/doctor_agent.py @@ -3,7 +3,7 @@ from typing import Any, Dict from ..core.base import BaseAgent -from ..core.models import ConflictResult, DocumentPair, PropositionResult +from ..core.models import ConflictPair, ConflictResult, DocumentPair, PropositionResult from ..core.temporal_analysis import TemporalAnalyzer prompts_dir = Path(__file__).parent.parent.parent / "prompts" @@ -100,29 +100,50 @@ def __call__( parsed_response = self._parse_json_response(response) # Validate required fields - required_fields = ["conflict_type", "reasoning", "modification_instructions"] - for field in required_fields: - if field not in parsed_response: - raise ValueError(f"Missing required field '{field}' in Doctor Agent response") - - # Validate conflict type exists - if parsed_response["conflict_type"] not in self.conflict_types: - self.logger.warning( - f"Unknown conflict type '{parsed_response['conflict_type']}', \ - defaulting to 'opposition'" + if "conflict_pairs" not in parsed_response: + raise ValueError("Missing required field 'conflict_pairs' in Doctor Agent response") + + conflict_pairs = [] + for pair_data in parsed_response["conflict_pairs"]: + # Validate required fields for each conflict pair + required_fields = ["conflict_type", "reasoning", "modification_instructions"] + for field in required_fields: + if field not in pair_data: + self.logger.warning( + f"Skipping conflict pair due to missing required field '{field}'" + ) + continue + + # Validate conflict type exists + if pair_data["conflict_type"] not in self.conflict_types: + self.logger.warning( + f"Skipping invalid conflict type '{pair_data['conflict_type']}'. " + f"Valid types: {list(self.conflict_types.keys())}" + ) + continue + + conflict_pair = ConflictPair( + conflict_type=pair_data["conflict_type"], + reasoning=pair_data["reasoning"], + modification_instructions=pair_data["modification_instructions"], + editor_instructions=pair_data.get("editor_instructions", []), + proposition_conflicts=pair_data.get("proposition_conflicts", []), ) - parsed_response["conflict_type"] = "opposition" - - result = ConflictResult( - conflict_type=parsed_response["conflict_type"], - reasoning=parsed_response["reasoning"], - modification_instructions=parsed_response["modification_instructions"], - editor_instructions=parsed_response.get("editor_instructions", []), - proposition_conflicts=parsed_response.get("proposition_conflicts", []), - ) + conflict_pairs.append(conflict_pair) + + # Check if we have any valid conflict pairs + if not conflict_pairs: + raise ValueError( + "No valid conflict pairs found in Doctor Agent response. " + "All conflict pairs were skipped due to validation errors." + ) + + result = ConflictResult(conflict_pairs=conflict_pairs) self.logger.info("Doctor Agent completed analysis") - self.logger.info(f"Selected conflict type: {result.conflict_type}") + self.logger.info(f"Selected {len(result.conflict_pairs)} conflict pairs") + for pair in result.conflict_pairs: + self.logger.info(f" - {pair.conflict_type}: {pair.reasoning[:100]}...") self.logger.info( f"Temporal context: {temporal_analysis.get('time_context', 'Unknown')}" ) diff --git a/lib/conflicts/conflicts/agents/editor_agent.py b/lib/conflicts/conflicts/agents/editor_agent.py index 71ebd98..3d06d9a 100644 --- a/lib/conflicts/conflicts/agents/editor_agent.py +++ b/lib/conflicts/conflicts/agents/editor_agent.py @@ -2,7 +2,7 @@ from ..core.base import BaseAgent from ..core.document_operations import parse_response -from ..core.models import ConflictResult, DocumentPair, EditorResult +from ..core.models import ConflictPair, DocumentPair, EditorResult prompts_dir = Path(__file__).parent.parent.parent / "prompts" EDITOR_SYSTEM_PROMPT_PATH = prompts_dir / "editor_agent_system.txt" @@ -21,7 +21,7 @@ def __init__(self, client, model, cfg): self.min_text_length = cfg.editor.min_text_length def __call__( - self, document_pair: DocumentPair, conflict_instructions: ConflictResult + self, document_pair: DocumentPair, conflict_instructions: ConflictPair ) -> EditorResult: """ Modify documents to introduce the specified conflict @@ -54,7 +54,7 @@ def __call__( self.logger.warning(f"Attempt {attempt + 1} failed: {e}, retrying...") def _perform_modification( - self, document_pair: DocumentPair, conflict_instructions: ConflictResult + self, document_pair: DocumentPair, conflict_instructions: ConflictPair ) -> EditorResult: """Perform a single modification attempt""" prompt = self._build_prompt(document_pair, conflict_instructions) @@ -63,7 +63,7 @@ def _perform_modification( return self._create_result(parsed_result, document_pair) def _build_prompt( - self, document_pair: DocumentPair, conflict_instructions: ConflictResult + self, document_pair: DocumentPair, conflict_instructions: ConflictPair ) -> str: """Build the prompt for modification""" # Extract specific propositions for each document diff --git a/lib/conflicts/conflicts/core/base.py b/lib/conflicts/conflicts/core/base.py index 156b73c..fe40fd1 100644 --- a/lib/conflicts/conflicts/core/base.py +++ b/lib/conflicts/conflicts/core/base.py @@ -46,6 +46,9 @@ class DocumentData: timestamp_1: Optional[str] timestamp_2: Optional[str] created_at: Optional[str] + moderator_score: Optional[int] = None + moderator_reasoning: Optional[str] = None + conflict_type: Optional[str] = None @dataclass @@ -185,6 +188,9 @@ def save_validated_documents( created_at=datetime.now().isoformat(), timestamp_1=str(original_pair.doc1_timestamp) if original_pair.doc1_timestamp else None, timestamp_2=str(original_pair.doc2_timestamp) if original_pair.doc2_timestamp else None, + moderator_score=validation_result.score, + moderator_reasoning=validation_result.reasoning, + conflict_type=conflict_type, ) # Create annotations list diff --git a/lib/conflicts/conflicts/core/models.py b/lib/conflicts/conflicts/core/models.py index a618582..1fbd6c5 100644 --- a/lib/conflicts/conflicts/core/models.py +++ b/lib/conflicts/conflicts/core/models.py @@ -19,8 +19,8 @@ class DocumentPair: @dataclass -class ConflictResult: - """Result from the Doctor Agent""" +class ConflictPair: + """Represents a single conflict pair for a specific conflict type""" conflict_type: str reasoning: str @@ -29,6 +29,13 @@ class ConflictResult: proposition_conflicts: Optional[list[dict]] = None +@dataclass +class ConflictResult: + """Result from the Doctor Agent - now contains multiple conflict pairs""" + + conflict_pairs: list[ConflictPair] + + @dataclass class EditorResult: """Result from the Editor Agent""" diff --git a/lib/conflicts/conflicts/core/pipeline.py b/lib/conflicts/conflicts/core/pipeline.py index 4dc720b..510a5d6 100644 --- a/lib/conflicts/conflicts/core/pipeline.py +++ b/lib/conflicts/conflicts/core/pipeline.py @@ -134,16 +134,15 @@ def process_document_pair(self, document_pair: DocumentPair) -> Tuple[bool, Dict result_data = { "pair_id": pair_id, "success": False, - "conflict_type": None, + "conflict_pairs": 0, + "successful_pairs": 0, + "total_pairs": 0, "processing_time": 0, "proposition_result": None, "doctor_result": None, - "editor_result": None, - "moderator_result": None, + "conflict_pair_results": [], "proposition_time": 0, "doctor_time": 0, - "editor_time": 0, - "moderator_time": 0, } # Step 1: Proposition Agent decomposes documents into propositions @@ -155,78 +154,111 @@ def process_document_pair(self, document_pair: DocumentPair) -> Tuple[bool, Dict result_data["proposition_result"] = proposition_result result_data["proposition_time"] = proposition_time - # Step 2: Doctor Agent identifies conflict type using propositions - + # Step 2: Doctor Agent identifies conflict pairs using propositions conflict_result, doctor_time = self._execute_agent( self.doctor_agent, document_pair, proposition_result[0], proposition_result[1] ) result_data["doctor_result"] = conflict_result result_data["doctor_time"] = doctor_time - result_data["conflict_type"] = conflict_result.conflict_type + result_data["conflict_pairs"] = len(conflict_result.conflict_pairs) - # Step 3: Editor and Moderator agents with retry logic for editor only - validation_result = None - editor_result = None + # Step 3: Process each conflict pair through Editor and Moderator agents + all_results = [] + successful_pairs = 0 - for attempt in range(1, self.max_retries + 1): - # Execute editor agent - editor_result, editor_time = self._execute_agent( - self.editor_agent, document_pair, conflict_result + for i, conflict_pair in enumerate(conflict_result.conflict_pairs): + self.logger.info( + f"Processing conflict pair {i+1}/{len(conflict_result.conflict_pairs)}:" + f" {conflict_pair.conflict_type}" ) - result_data["editor_result"] = editor_result - result_data["editor_time"] = editor_time - # Check if editor agent failed to create modifications - if "Failed to create conflict" in editor_result.changes_made: - self.logger.warning( - "Editor agent failed to create modifications, skipping moderator validation" + pair_result = { + "conflict_type": conflict_pair.conflict_type, + "success": False, + "editor_result": None, + "moderator_result": None, + "editor_time": 0, + "moderator_time": 0, + "attempts": 0, + } + + validation_result = None + editor_result = None + + for attempt in range(1, self.max_retries + 1): + pair_result["attempts"] = attempt + + # Execute editor agent + editor_result, editor_time = self._execute_agent( + self.editor_agent, document_pair, conflict_pair ) - validation_result = ValidationResult( - is_valid=False, - score=1, - reasoning="Editor agent failed to modify - no changes to validate", + pair_result["editor_result"] = editor_result + pair_result["editor_time"] = editor_time + + # Check if editor agent failed to create modifications + if "Failed to create conflict" in editor_result.changes_made: + self.logger.warning( + f"Editor agent failed to create modifications for " + f"{conflict_pair.conflict_type}, skipping moderator validation" + ) + validation_result = ValidationResult( + is_valid=False, + score=1, + reasoning="Editor agent failed to modify - no changes to validate", + ) + pair_result["moderator_result"] = validation_result + pair_result["moderator_time"] = 0 + break + + # Execute moderator agent for validation + validation_result, moderator_time = self._execute_agent( + self.moderator_agent, document_pair, editor_result, conflict_pair.conflict_type ) - result_data["moderator_result"] = validation_result - result_data["moderator_time"] = 0 - break + pair_result["moderator_result"] = validation_result + pair_result["moderator_time"] = moderator_time - # Execute moderator agent for validation - validation_result, moderator_time = self._execute_agent( - self.moderator_agent, document_pair, editor_result, conflict_result.conflict_type - ) - result_data["moderator_result"] = validation_result - result_data["moderator_time"] = moderator_time + self.logger.info( + f"Attempt {attempt}: {conflict_pair.conflict_type} conflict, " + f"valid={validation_result.is_valid}, score={validation_result.score}/5" + ) - self.logger.info( - f"Attempt {attempt}: {conflict_result.conflict_type} conflict, " - f"valid={validation_result.is_valid}, score={validation_result.score}/5" - ) + if validation_result.is_valid: + pair_result["success"] = True + successful_pairs += 1 + break + + if attempt < self.max_retries: + self.logger.warning( + f"Validation failed for {conflict_pair.conflict_type}, retrying..." + ) + time.sleep(1) + + # Save to database if validation passed + if validation_result and validation_result.is_valid: + is_success = self._save_to_database( + f"{pair_id}_{conflict_pair.conflict_type}", + document_pair, + editor_result, + conflict_pair.conflict_type, + validation_result, + ) + pair_result["success"] = is_success - if validation_result.is_valid: - result_data["success"] = True - break - - if attempt < self.max_retries: - self.logger.warning("Validation failed, retrying...") - time.sleep(1) - - # Step 3: Save to database if validation passed - if validation_result and validation_result.is_valid: - is_success = self._save_to_database( - pair_id, - document_pair, - editor_result, - conflict_result.conflict_type, - validation_result, - ) - result_data["success"] = is_success + all_results.append(pair_result) + + # Update result data with all conflict pair results + result_data["conflict_pair_results"] = all_results + result_data["successful_pairs"] = successful_pairs + result_data["total_pairs"] = len(conflict_result.conflict_pairs) + result_data["success"] = successful_pairs > 0 # Success if at least one pair succeeded result_data["processing_time"] = time.time() - start_time # Summary log status = "SUCCESS" if result_data["success"] else "FAILED" self.logger.info( - f"Pair {pair_id}: {status} - {conflict_result.conflict_type} conflict, " + f"{pair_id}: {status} - {result_data['successful_pairs']}/{result_data['total_pairs']} " + "conflict pairs successful, " f"{proposition_result[0].total_propositions + proposition_result[1].total_propositions}" f" propositions" ) @@ -272,16 +304,15 @@ def execute( failed_result = { "pair_id": f"{doc_pair.doc1_id}_{doc_pair.doc2_id}", "success": False, - "conflict_type": None, + "conflict_pairs": 0, + "successful_pairs": 0, + "total_pairs": 0, "processing_time": 0, "proposition_result": None, "doctor_result": None, - "editor_result": None, - "moderator_result": None, + "conflict_pair_results": [], "proposition_time": 0, "doctor_time": 0, - "editor_time": 0, - "moderator_time": 0, "error": str(e), } results.append(failed_result) diff --git a/lib/conflicts/label-studio.xml b/lib/conflicts/label-studio.xml index 1bdeb5b..0c81758 100644 --- a/lib/conflicts/label-studio.xml +++ b/lib/conflicts/label-studio.xml @@ -17,15 +17,64 @@ -
-
+ + +
+ +
+ + +
+ + +
+ + + + +
+
-
- -
-