Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 5,459 Bytes
fab8051 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
#!/usr/bin/env python3
"""End-to-end validation tests for radiology report structuring.

This module provides focused validation tests for the complete
RadiologyReportStructurer pipeline: known good cached outputs are checked
for structural validity, and the processing path is exercised with the
extraction call mocked out.

Typical usage example:

  # Run with unittest (built-in)
  python test_validation.py
  python -m unittest test_validation.py -v

  # Run with pytest (recommended for CI/CD)
  pytest test_validation.py -v
"""
import json
import os
import sys
import unittest
from typing import Any
from unittest import mock
# Add the current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from structure_report import RadiologyReportStructurer
class TestRadiologyReportEndToEnd(unittest.TestCase):
    """End-to-end tests for complete RadiologyReportStructurer pipeline.

    The suite checks the structure of cached sample responses and exercises
    ``RadiologyReportStructurer.predict`` together with a few of its private
    helpers.
    """

    # Path of the JSON file that holds the cached sample responses.
    cache_file: str
    # Parsed cache contents, keyed by sample identifier.
    sample_data: dict[str, Any]
    # Structurer shared by the tests, built with a placeholder API key.
    structurer: RadiologyReportStructurer

    @classmethod
    def setUpClass(cls):
        """Load the sample cache and build the shared structurer once."""
        super().setUpClass()
        # Anchor the cache path to this file's directory so the suite does
        # not depend on the current working directory of the test runner.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        cls.cache_file = os.path.join(module_dir, 'cache', 'sample_cache.json')
        cls.sample_data = cls._load_sample_cache()
        cls.structurer = RadiologyReportStructurer(
            api_key='test_key', model_id='gemini-2.5-flash'
        )

    @classmethod
    def _load_sample_cache(cls) -> dict[str, Any]:
        """Read and decode the JSON document at ``cls.cache_file``.

        Returns:
          The decoded JSON content of the cache file.

        Raises:
          FileNotFoundError: If the cache file does not exist.
        """
        try:
            with open(cls.cache_file, encoding='utf-8') as cache:
                return json.load(cache)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f'Sample cache file not found: {cls.cache_file}'
            ) from err

    def _validate_response_structure(self, response: dict[str, Any]) -> None:
        """Assert that a response has a list of segments and a text string."""
        self.assertIn('segments', response)
        self.assertIn('text', response)
        self.assertIsInstance(response['segments'], list)
        self.assertIsInstance(response['text'], str)

    def _validate_successful_response(self, response: dict[str, Any]) -> None:
        """Assert a well-formed response with non-empty segments and text."""
        self._validate_response_structure(response)
        self.assertGreater(len(response['segments']), 0)
        self.assertGreater(len(response['text']), 0)
        for segment in response['segments']:
            self._validate_segment_structure(segment)

    def _validate_segment_structure(self, segment: dict[str, Any]) -> None:
        """Assert that a segment has the required fields and sane intervals."""
        for field in ('type', 'label', 'content', 'intervals'):
            self.assertIn(field, segment)
        self.assertIn(segment['type'], ('prefix', 'body', 'suffix'))
        # ``or ()`` skips the loop for any falsy ``intervals`` value, exactly
        # as an explicit truthiness guard would.
        for interval in segment['intervals'] or ():
            self.assertIn('startPos', interval)
            self.assertIn('endPos', interval)
            self.assertGreaterEqual(interval['startPos'], 0)
            # Every interval must span at least one character.
            self.assertGreater(interval['endPos'], interval['startPos'])

    @mock.patch('structure_report.lx.extract')
    def test_end_to_end_processing_pipeline(self, mock_extract):
        """Run ``predict`` with extraction mocked and inspect the call made."""
        # Stand-in extraction result that carries no extractions.
        mock_extract.return_value = mock.MagicMock(extractions=[])
        input_text = (
            'EXAMINATION: Chest CT\n\n'
            'FINDINGS: Normal lungs.\n\n'
            'IMPRESSION: No acute findings.'
        )

        response = self.structurer.predict(input_text)

        self._validate_response_structure(response)
        mock_extract.assert_called_once()
        extract_kwargs = mock_extract.call_args.kwargs
        self.assertEqual(extract_kwargs['text_or_documents'], input_text)
        self.assertEqual(extract_kwargs['model_id'], 'gemini-2.5-flash')

    def test_all_cached_samples_validation(self):
        """Validate every cached sample as a successful response."""
        self.assertGreater(
            len(self.sample_data), 0, 'No samples found in cache'
        )
        for sample_key, sample in self.sample_data.items():
            # subTest reports each failing sample individually.
            with self.subTest(sample=sample_key):
                self._validate_successful_response(sample)

    def test_error_handling_with_invalid_input(self):
        """Empty and whitespace-only reports must raise ``ValueError``."""
        with self.assertRaises(ValueError) as context:
            self.structurer.predict('')
        self.assertIn('Report text cannot be empty', str(context.exception))
        with self.assertRaises(ValueError):
            self.structurer.predict(' \n\t ')

    def test_error_handling_with_no_api_key(self):
        """Without an API key ``predict`` returns an error response."""
        # NOTE(review): presumably no key is picked up from the environment
        # here; confirm, otherwise this test could reach a live service.
        error_structurer = RadiologyReportStructurer(api_key=None)
        response = error_structurer.predict('EXAMINATION: Test')
        self._validate_response_structure(response)
        self.assertEqual(len(response['segments']), 0)
        self.assertIn('Error processing report', response['text'])

    def test_patch_initialization_on_first_use(self):
        """Patches are not initialized until explicitly ensured."""
        new_structurer = RadiologyReportStructurer()
        self.assertFalse(new_structurer._patches_initialized)
        new_structurer._ensure_patches_initialized()
        self.assertTrue(new_structurer._patches_initialized)

    def test_section_mapping_core_functionality(self):
        """A known label maps to a stable value; unknown labels map to None."""
        mapped = self.structurer._map_section('findings_prefix')
        # Comparing the call only with itself can never fail, so also require
        # a real mapping. NOTE(review): assumes 'findings_prefix' is a
        # recognised label; confirm against structure_report and, ideally,
        # pin the exact expected value here.
        self.assertIsNotNone(mapped)
        self.assertEqual(
            mapped, self.structurer._map_section('findings_prefix')
        )
        self.assertIsNone(self.structurer._map_section('invalid_section'))
        self.assertIsNone(self.structurer._map_section(''))

    def test_exam_prefix_stripping(self):
        """The 'EXAMINATION:' prefix is removed; other text is unchanged."""
        self.assertEqual(
            self.structurer._strip_exam_prefix('EXAMINATION: Chest CT'),
            'Chest CT',
        )
        self.assertEqual(
            self.structurer._strip_exam_prefix('Normal findings'),
            'Normal findings',
        )
# Allow running this file directly as a script; verbosity=2 makes unittest
# print one result line per test.
if __name__ == '__main__':
    unittest.main(verbosity=2)
|