diff --git a/README.md b/README.md
index c45a6ef..3c9d80a 100644
--- a/README.md
+++ b/README.md
@@ -112,6 +112,9 @@ for result in results["completed"]:
     print(f" Vendor: {analysis.vendor} (page: {citations.get("vendor").page})")
     print(f" Total: ${analysis.total_amount:.2f} (page: {citations.get("total_amount").page})")
     print(f" Status: {analysis.payment_status} (page: {citations.get("payment_status").page})")
+
+    # Save each result to JSON file
+    result.save_to_json(f"./invoice_results/{result.job_id}.json")
 
 # Process failed/cancelled results
 for result in results["failed"]:
@@ -121,7 +124,6 @@ for result in results["cancelled"]:
     print(f"\nJob {result.job_id} was cancelled: {result.error}")
 ```
 
-
 ## Interactive Progress Display
 
 Batchata provides an interactive real-time progress display when using `print_status=True`:
diff --git a/batchata/core/job_result.py b/batchata/core/job_result.py
index 3f57ba1..eaeba3d 100644
--- a/batchata/core/job_result.py
+++ b/batchata/core/job_result.py
@@ -1,6 +1,6 @@
 """JobResult data model."""
 
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel
@@ -88,6 +88,29 @@ def to_dict(self) -> Dict[str, Any]:
             "batch_id": self.batch_id
         }
 
+    def save_to_json(self, filepath: str, indent: int = 2) -> None:
+        """Save JobResult to a JSON file.
+
+        Writes the dictionary returned by ``to_dict()``. Missing parent
+        directories are created, and an existing file at ``filepath`` is
+        overwritten.
+
+        Args:
+            filepath: Path to save the JSON file
+            indent: JSON indentation (default: 2)
+
+        Raises:
+            OSError: If the directory or file cannot be created or written
+        """
+        import json
+        from pathlib import Path
+
+        path = Path(filepath)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        # Explicit encoding so the output does not depend on the platform locale
+        with path.open('w', encoding='utf-8') as f:
+            json.dump(self.to_dict(), f, indent=indent)
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
         """Deserialize from state."""
diff --git a/tests/core/test_job_result.py b/tests/core/test_job_result.py
index 77509fa..e7eb687 100644
--- a/tests/core/test_job_result.py
+++ b/tests/core/test_job_result.py
@@ -316,4 +316,63 @@ def test_citation_mappings_json_serialization(self):
         assert len(restored.citations) == 2
         assert len(restored.citation_mappings) == 3
         assert len(restored.citation_mappings['cap_rate']) == 2
-        assert len(restored.citation_mappings['occupancy']) == 1
\ No newline at end of file
+        assert len(restored.citation_mappings['occupancy']) == 1
+
+    def test_save_to_json(self, tmp_path):
+        """Test that save_to_json() correctly saves JobResult to a JSON file."""
+        # Create a JobResult with citations and citation_mappings
+        citations = [
+            Citation(
+                text='Test citation text',
+                source='test.pdf',
+                page=1,
+                metadata={'type': 'page_location', 'document_index': 0}
+            )
+        ]
+
+        citation_mappings = {
+            'test_field': citations
+        }
+
+        result = JobResult(
+            job_id="test-save-json",
+            raw_response="Test response",
+            parsed_response={'test_field': 'test_value'},
+            citations=citations,
+            citation_mappings=citation_mappings,
+            input_tokens=100,
+            output_tokens=50,
+            cost_usd=0.05
+        )
+
+        # Save to JSON file
+        json_file = tmp_path / "subdir" / "test_result.json"
+        result.save_to_json(str(json_file))
+
+        # Verify file was created
+        assert json_file.exists()
+
+        # Verify content is correct by loading and comparing
+        import json
+        with open(json_file, 'r', encoding='utf-8') as f:
+            saved_data = json.load(f)
+
+        # Should match the result of to_dict()
+        expected_data = result.to_dict()
+        assert saved_data == expected_data
+
+        # Verify specific fields
+        assert saved_data['job_id'] == 'test-save-json'
+        assert saved_data['input_tokens'] == 100
+        assert saved_data['output_tokens'] == 50
+        assert saved_data['cost_usd'] == 0.05
+
+        # Verify citations are properly serialized (not Citation objects)
+        assert isinstance(saved_data['citations'][0], dict)
+        assert saved_data['citations'][0]['text'] == 'Test citation text'
+        assert saved_data['citations'][0]['source'] == 'test.pdf'
+        assert saved_data['citations'][0]['page'] == 1
+
+        # Verify citation_mappings are properly serialized
+        assert isinstance(saved_data['citation_mappings']['test_field'][0], dict)
+        assert saved_data['citation_mappings']['test_field'][0]['text'] == 'Test citation text'
diff --git a/uv.lock b/uv.lock
index 4165959..65688ec 100644
--- a/uv.lock
+++ b/uv.lock
@@ -130,7 +130,7 @@ wheels = [
 
 [[package]]
 name = "batchata"
-version = "0.4.3"
+version = "0.4.5"
 source = { editable = "." }
 dependencies = [
     { name = "anthropic" },