Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 39 additions & 111 deletions lib/conflicts/conflicts/agents/doctor_agent.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict

from ..core.base import BaseAgent
from ..core.constants import (
PRE_POST_CARE_CONFLICT_TYPE,
SPECIALIZED_CONFLICT_TYPES,
TEMPORALITY_CONFLICT_TYPE,
)
from ..core.models import ConflictResult, DocumentPair, PropositionResult
from ..core.temporal_analysis import TemporalAnalyzer

prompts_dir = Path(__file__).parent.parent.parent / "prompts"
DOCTOR_SYSTEM_PROMPT_PATH = prompts_dir / "doctor_agent_system.txt"


@dataclass
class ConflictType:
    """Record describing one category of clinical conflict.

    Holds the name of the conflict type, a prose description of it, and a
    list of example strings illustrating it.
    """

    # Name of the conflict type.
    name: str
    # Free-text description of the conflict type.
    description: str
    # Example strings illustrating the conflict type.
    examples: list[str]
DOCTOR_PRE_POST_CARE_PROMPT_PATH = prompts_dir / "doctor_agent_pre_post_care_system.txt"
DOCTOR_TEMPORALITY_PROMPT_PATH = prompts_dir / "doctor_agent_temporality_system.txt"


class DoctorAgent(BaseAgent):
Expand All @@ -25,54 +19,57 @@ class DoctorAgent(BaseAgent):
clinical conflict should be introduced between them.
"""

def __init__(self, client, model, cfg):
with open(DOCTOR_SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
def __init__(self, client, model, cfg, conflict_type: str):
if conflict_type not in SPECIALIZED_CONFLICT_TYPES:
raise ValueError(
f"Invalid conflict_type: {conflict_type}. "
f"Must be one of: {SPECIALIZED_CONFLICT_TYPES}"
)

# Load appropriate prompt based on conflict type
prompt_path, agent_name = self._get_prompt_path_and_name(conflict_type)

with open(prompt_path, "r", encoding="utf-8") as f:
prompt = f.read().strip()
super().__init__("Doctor", client, model, cfg, prompt)
self.conflict_types = self._load_conflict_types(cfg)
super().__init__(agent_name, client, model, cfg, prompt)
self.conflict_type_specialization = conflict_type

@staticmethod
def _get_prompt_path_and_name(conflict_type: str):
    """Resolve the prompt file and agent name for a conflict type.

    Args:
        conflict_type: Conflict type identifier to look up.

    Returns:
        A ``(prompt_path, agent_name)`` tuple for the matching conflict type.

    Raises:
        ValueError: If ``conflict_type`` matches neither supported type.
    """
    # Each entry pairs a supported conflict type with its (prompt path, agent name).
    specializations = (
        (
            PRE_POST_CARE_CONFLICT_TYPE,
            (DOCTOR_PRE_POST_CARE_PROMPT_PATH, "Doctor-PrePostCare"),
        ),
        (
            TEMPORALITY_CONFLICT_TYPE,
            (DOCTOR_TEMPORALITY_PROMPT_PATH, "Doctor-Temporality"),
        ),
    )
    for supported_type, resolved in specializations:
        if conflict_type == supported_type:
            return resolved
    raise ValueError(f"Unknown conflict type: {conflict_type}")

def __call__(
self,
document_pair: DocumentPair,
propositions1: PropositionResult = None,
propositions2: PropositionResult = None,
conflict_type: str = None,
conflict_type: str = None, # Kept for API compatibility but not used
) -> ConflictResult:
"""
Analyze documents and choose proposition pairs for a specific conflict type
Analyze documents and choose proposition pairs for the specialized conflict type

Args:
document_pair: Pair of clinical documents to analyze
propositions1: Optional PropositionResult from document 1
propositions2: Optional PropositionResult from document 2
conflict_type: Specific conflict type to create
conflict_type: Not used (kept for API compatibility)

Returns:
ConflictResult containing the chosen proposition pairs and instructions
"""
self.logger.info(
f"Analyzing document pair: {document_pair.doc1_id} & {document_pair.doc2_id}"
f" for conflict type: {conflict_type}"
f" for conflict type: {self.conflict_type_specialization}"
)

try:
# Perform temporal analysis
temporal_analyzer = TemporalAnalyzer()
temporal_analysis = temporal_analyzer.analyze_temporal_relationship(
document_pair.doc1_timestamp, document_pair.doc2_timestamp
)

# Get temporal conflict recommendations
temporal_recommendations = temporal_analyzer.get_temporal_conflict_recommendations(
temporal_analysis
)

# Prepare prompt with conflict types, temporal info, and documents
conflict_type_info = self.get_conflict_type_info(conflict_type)
temporal_context = temporal_analyzer.format_temporal_context_for_prompt(
temporal_analysis
)
temporal_recommendations_str = ", ".join(temporal_recommendations)
# Prepare propositions strings
propositions1_str = (
"\n".join([f"{i}. {prop}" for i, prop in enumerate(propositions1.propositions, 1)])
if propositions1 and propositions1.propositions
Expand All @@ -85,14 +82,6 @@ def __call__(
)

prompt = self.system_prompt.format(
conflict_type=conflict_type,
conflict_type_name=conflict_type_info["name"],
conflict_type_description=conflict_type_info["description"],
conflict_type_examples="\n".join(
[f"- {example}" for example in conflict_type_info["examples"]]
),
temporal_context=temporal_context,
temporal_recommendations=temporal_recommendations_str,
document1=self._truncate_document(document_pair.doc1_text),
document2=self._truncate_document(document_pair.doc2_text),
propositions1=propositions1_str,
Expand All @@ -113,8 +102,11 @@ def __call__(
if field not in parsed_response:
raise ValueError(f"Missing required field '{field}' in Doctor Agent response")

# Use the specialized conflict type
result_conflict_type = self.conflict_type_specialization

result = ConflictResult(
conflict_type=conflict_type,
conflict_type=result_conflict_type,
reasoning=parsed_response["reasoning"],
modification_instructions=parsed_response["modification_instructions"],
editor_instructions=parsed_response.get("editor_instructions", []),
Expand All @@ -123,73 +115,9 @@ def __call__(

self.logger.info("Doctor Agent completed analysis")
self.logger.info(f"Selected conflict type: {result.conflict_type}")
self.logger.info(
f"Temporal context: {temporal_analysis.get('time_context', 'Unknown')}"
)

return result

except Exception as e:
self.logger.error(f"Error in Doctor Agent: {e}")
raise

def get_conflict_type_info(self, conflict_type: str) -> Dict[str, Any]:
    """
    Look up one conflict type and return its details as a plain dictionary

    Args:
        conflict_type: Key of the conflict type in ``self.conflict_types``

    Returns:
        Dictionary with the entry's ``name``, ``description`` and ``examples``,
        plus the lookup key under ``key``

    Raises:
        ValueError: If the key is not present in ``self.conflict_types``
    """
    registry = self.conflict_types
    if conflict_type in registry:
        entry = registry[conflict_type]
        return dict(
            name=entry.name,
            description=entry.description,
            examples=entry.examples,
            key=conflict_type,
        )

    raise ValueError(f"Unknown conflict type: {conflict_type}")

def list_all_conflict_types(self) -> Dict[str, Dict[str, Any]]:
    """
    Summarise every entry in ``self.conflict_types`` as plain dictionaries

    Returns:
        Dictionary mapping each conflict type key to a dictionary holding
        that entry's ``name``, ``description`` and ``examples``
    """
    catalog = {}
    for type_key, entry in self.conflict_types.items():
        catalog[type_key] = dict(
            name=entry.name,
            description=entry.description,
            examples=entry.examples,
        )
    return catalog

def _load_conflict_types(self, cfg) -> Dict[str, ConflictType]:
    """Build the conflict type registry from ``cfg.doctor.conflict_types``.

    Each configuration entry is converted into a ``ConflictType``; the
    entry's ``examples`` are copied into a new list.
    """
    return {
        type_key: ConflictType(
            name=entry.name,
            description=entry.description,
            examples=list(entry.examples),
        )
        for type_key, entry in cfg.doctor.conflict_types.items()
    }

def format_conflict_types_for_prompt(self) -> str:
"""Format conflict types dictionary for use in prompts

Builds one text entry per item in ``self.conflict_types`` (key, name,
description, and a dash-prefixed list of examples) and joins the entries
with newlines.

Returns:
The formatted entries joined into a single string
"""
# Collects one formatted entry per conflict type.
formatted_types = []
for key, conflict_type in self.conflict_types.items():
# Prefix every example with "- " and separate them with a newline plus padding.
examples_str = "\n ".join([f"- {example}" for example in conflict_type.examples])
# NOTE(review): whitespace inside the template below is emitted verbatim
# into the prompt - confirm the intended layout before reformatting it.
formatted_types.append(
f"""{key}: {conflict_type.name}
Description: {conflict_type.description}
Examples:
{examples_str}
"""
)
return "\n".join(formatted_types)
8 changes: 8 additions & 0 deletions lib/conflicts/conflicts/core/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Constants for conflict types and agent specializations"""

# Identifiers of the conflict types that have a specialized agent
PRE_POST_CARE_CONFLICT_TYPE = "pre_post_care_evolution"
TEMPORALITY_CONFLICT_TYPE = "temporality"

# Every specialized conflict type, in declaration order
SPECIALIZED_CONFLICT_TYPES = [
    PRE_POST_CARE_CONFLICT_TYPE,
    TEMPORALITY_CONFLICT_TYPE,
]
62 changes: 40 additions & 22 deletions lib/conflicts/conflicts/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..agents.moderator_agent import ModeratorAgent
from ..agents.proposition_agent import PropositionAgent
from .base import ConflictDataItem, DatasetManager
from .constants import PRE_POST_CARE_CONFLICT_TYPE, TEMPORALITY_CONFLICT_TYPE
from .data_loader import DataLoader
from .models import DocumentPair, ValidationResult

Expand Down Expand Up @@ -59,8 +60,18 @@ def __init__(self, cfg: DictConfig):

# Initialize agents with shared client and configuration
self.proposition_agent = PropositionAgent(self.client, cfg.model.name, cfg)
self.doctor_agent = DoctorAgent(self.client, cfg.model.name, cfg)

# Initialize specialized doctor agents for the two conflict types
self.doctor_agent_pre_post_care = DoctorAgent(
self.client, cfg.model.name, cfg, conflict_type=PRE_POST_CARE_CONFLICT_TYPE
)
self.doctor_agent_temporality = DoctorAgent(
self.client, cfg.model.name, cfg, conflict_type=TEMPORALITY_CONFLICT_TYPE
)

# Initialize a single editor agent (works for all conflict types)
self.editor_agent = EditorAgent(self.client, cfg.model.name, cfg)

self.moderator_agent = ModeratorAgent(
self.client,
cfg.model.name,
Expand Down Expand Up @@ -251,25 +262,29 @@ def process_document_pair(self, document_pair: DocumentPair) -> Tuple[bool, Dict
result_data["proposition_result"] = proposition_result
result_data["proposition_time"] = proposition_time

# Step 2: Try each conflict type with Doctor Agent choosing proposition pairs
all_attempts = [] # Simplified: single list of all attempts with metadata
best_result = None # Track the best result found so far

# Get all available conflict types
conflict_types = list(self.doctor_agent.list_all_conflict_types().keys())

if not conflict_types:
self.logger.error("No conflict types available - cannot process document pair")
result_data["processing_time"] = time.time() - start_time
return False, result_data

for conflict_type in conflict_types:
self.logger.info(f"Trying conflict type: {conflict_type}")
# Step 2: Process both specialized conflict types for this proposition set
all_attempts = []
best_result = None

# Process both specialized conflict types
conflict_type_configs = [
(
PRE_POST_CARE_CONFLICT_TYPE,
self.doctor_agent_pre_post_care,
),
(
TEMPORALITY_CONFLICT_TYPE,
self.doctor_agent_temporality,
),
]

for conflict_type, doctor_agent in conflict_type_configs:
self.logger.info(f"Processing conflict type: {conflict_type}")

# Doctor Agent chooses proposition pairs for this conflict type
try:
conflict_result, doctor_time = self._execute_agent(
self.doctor_agent,
doctor_agent,
document_pair,
proposition_result[0],
proposition_result[1],
Expand Down Expand Up @@ -410,10 +425,11 @@ def process_document_pair(self, document_pair: DocumentPair) -> Tuple[bool, Dict
if attempt["validation_result"].is_valid
)
)
total_conflict_types = len(set(attempt["conflict_type"] for attempt in all_attempts))

self.logger.info(
f"Pair {pair_id}: {status} - {final_result['conflict_type']} conflict "
f"(best of {successful_conflict_types}/{len(conflict_types)} successful types), "
f"(best of {successful_conflict_types}/{total_conflict_types} successful types), "
f"{proposition_result[0].total_propositions + proposition_result[1].total_propositions}"
f" propositions, {successful_attempts}/{total_attempts} valid"
)
Expand Down Expand Up @@ -503,11 +519,13 @@ def get_pipeline_statistics(self) -> Dict[str, Any]:
"validated_documents": total_validated,
"dataset_statistics": data_stats,
"agents": {
"doctor": {
"name": self.doctor_agent.name,
"conflict_types_available": list(
self.doctor_agent.list_all_conflict_types().keys()
),
"doctor_pre_post_care": {
"name": self.doctor_agent_pre_post_care.name,
"conflict_type": PRE_POST_CARE_CONFLICT_TYPE,
},
"doctor_temporality": {
"name": self.doctor_agent_temporality.name,
"conflict_type": TEMPORALITY_CONFLICT_TYPE,
},
"editor": {"name": self.editor_agent.name},
"moderator": {
Expand Down
Loading