| | """
|
| | Test cases for the document text extraction system.
|
| | """
|
| |
|
| | import unittest
|
| | import json
|
| | from pathlib import Path
|
| | import tempfile
|
| | import os
|
| |
|
| | from src.data_preparation import DocumentProcessor, NERDatasetCreator
|
| | from src.model import ModelConfig, create_model_and_trainer
|
| | from src.inference import DocumentInference
|
| |
|
| |
|
class TestDocumentProcessor(unittest.TestCase):
    """Test cases for document processing."""

    def setUp(self):
        """Create a fresh ``DocumentProcessor`` for every test."""
        self.processor = DocumentProcessor()

    def test_clean_text(self):
        """``clean_text`` trims the outer whitespace and reduces '!!!' to '!'."""
        dirty_text = " This is a test text!!! "
        clean_text = self.processor.clean_text(dirty_text)
        self.assertEqual(clean_text, "This is a test text!")

    def test_entity_patterns(self):
        """The processor defines patterns for NAME, DATE, INVOICE_NO and AMOUNT.

        Only the presence of each key in ``entity_patterns`` is checked; the
        shape of the pattern values is not asserted here.
        """
        for entity in ('NAME', 'DATE', 'INVOICE_NO', 'AMOUNT'):
            # subTest reports every missing key instead of stopping at the first.
            with self.subTest(entity=entity):
                self.assertIn(entity, self.processor.entity_patterns)
|
| |
|
| |
|
class TestNERDatasetCreator(unittest.TestCase):
    """Tests for building NER training data with ``NERDatasetCreator``."""

    def setUp(self):
        """Wrap a fresh ``DocumentProcessor`` in a creator before each test."""
        self.processor = DocumentProcessor()
        self.dataset_creator = NERDatasetCreator(self.processor)

    def test_auto_label_text(self):
        """``auto_label_text`` yields a non-empty list of (token, label) string pairs."""
        sentence = "Invoice sent to Robert White on 15/09/2025 Amount: $1,250"
        pairs = self.dataset_creator.auto_label_text(sentence)

        self.assertIsInstance(pairs, list)
        self.assertGreater(len(pairs), 0)

        for token, label in pairs:
            for part in (token, label):
                self.assertIsInstance(part, str)

    def test_create_training_example(self):
        """A single example exposes tokens, labels and text, with aligned lengths."""
        example = self.dataset_creator.create_training_example(
            "Invoice INV-1001 for $500"
        )

        for key in ('tokens', 'labels', 'text'):
            self.assertIn(key, example)

        # Tokens and labels are expected to line up one-to-one.
        self.assertEqual(len(example['tokens']), len(example['labels']))

    def test_create_sample_dataset(self):
        """The built-in sample dataset is a non-empty list of well-keyed examples."""
        samples = self.dataset_creator.create_sample_dataset()

        self.assertIsInstance(samples, list)
        self.assertGreater(len(samples), 0)

        head = samples[0]
        for key in ('tokens', 'labels', 'text'):
            self.assertIn(key, head)
|
| |
|
| |
|
class TestModelConfig(unittest.TestCase):
    """Tests for ``ModelConfig`` defaults and keyword overrides."""

    def test_default_config(self):
        """A no-argument ``ModelConfig`` carries the expected default values."""
        cfg = ModelConfig()

        expected_defaults = (
            ('model_name', "distilbert-base-uncased"),
            ('max_length', 512),
            ('batch_size', 16),
        )
        for attr_name, expected in expected_defaults:
            self.assertEqual(getattr(cfg, attr_name), expected)

        labels = cfg.entity_labels
        self.assertIsInstance(labels, list)
        self.assertGreater(len(labels), 0)
        self.assertIn('O', labels)

        self.assertIsInstance(cfg.label2id, dict)
        self.assertIsInstance(cfg.id2label, dict)
        self.assertEqual(len(cfg.label2id), len(labels))

    def test_custom_config(self):
        """Keyword overrides are stored as given and ``num_labels`` matches the labels."""
        labels = ['O', 'B-TEST', 'I-TEST']
        cfg = ModelConfig(batch_size=32, learning_rate=1e-5, entity_labels=labels)

        self.assertEqual(cfg.batch_size, 32)
        self.assertEqual(cfg.learning_rate, 1e-5)
        self.assertEqual(cfg.entity_labels, labels)
        self.assertEqual(cfg.num_labels, len(labels))
|
| |
|
| |
|
class TestModelCreation(unittest.TestCase):
    """Tests for the ``create_model_and_trainer`` factory."""

    def test_create_model_and_trainer(self):
        """The factory returns a model and a trainer whose config reflects the overrides."""
        small_config = ModelConfig(
            batch_size=4,
            num_epochs=1,
            entity_labels=['O', 'B-TEST', 'I-TEST'],
        )

        model, trainer = create_model_and_trainer(small_config)

        for created in (model, trainer):
            self.assertIsNotNone(created)

        trainer_cfg = trainer.config
        self.assertEqual(trainer_cfg.batch_size, 4)
        self.assertEqual(trainer_cfg.num_epochs, 1)
|
| |
|
| |
|
class TestInference(unittest.TestCase):
    """Test cases for inference pipeline.

    NOTE(review): ``test_entity_validation`` checks the regexes defined in this
    class, not ``DocumentInference`` itself.
    """

    # Validation regex per entity type. ``test_entity_validation`` requires each
    # one to match every value listed for that type in VALID_SAMPLES.
    VALIDATION_PATTERNS = {
        'DATE': (
            r'\b(?:\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}'   # 01/15/2025
            r'|\d{4}-\d{1,2}-\d{1,2}'                 # 2025-01-15
            r'|(?:January|February|March|April|May|June|July|August'
            r'|September|October|November|December)'
            r'\s+\d{1,2},\s*\d{4})\b'                 # January 15, 2025
        ),
        'AMOUNT': (
            r'\$\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?'    # $500.00, $1,250.50
            r'|\b\d+(?:\.\d{2})?\s*USD\b'             # 1000.00 USD
        ),
        # The TLD class must be [A-Za-z]; writing [A-Z|a-z] would also accept '|'.
        'EMAIL': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'PHONE': (
            r'(?:\+\d{1,3}[-.\s]?)?'                  # optional +country code
            r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b'  # (555) 123-4567, 555-123-4567
        ),
    }

    # Sample values each validation regex must accept.
    VALID_SAMPLES = {
        'DATE': ('01/15/2025', '2025-01-15', 'January 15, 2025'),
        'AMOUNT': ('$500.00', '$1,250.50', '1000.00 USD'),
        'EMAIL': ('test@email.com', 'user.name@domain.co.uk'),
        'PHONE': ('(555) 123-4567', '+1-555-987-6543', '555-123-4567'),
    }

    @classmethod
    def setUpClass(cls):
        """Set up class-level fixtures."""
        # NOTE(review): no test in this class reads these attributes yet; they
        # are kept for tests that will exercise DocumentInference.
        cls.model_path = "test_model"
        cls.test_text = "Invoice sent to John Doe on 01/15/2025 Amount: $500.00"

    def test_entity_validation(self):
        """Each validation regex accepts all of its sample values."""
        import re

        for entity, samples in self.VALID_SAMPLES.items():
            pattern = re.compile(self.VALIDATION_PATTERNS[entity])
            for sample in samples:
                # subTest reports every rejected sample, not just the first.
                with self.subTest(entity=entity, sample=sample):
                    self.assertIsNotNone(pattern.search(sample))

        # A '|' must not be accepted as part of an email's top-level domain.
        self.assertIsNone(re.search(self.VALIDATION_PATTERNS['EMAIL'], 'a@b.c|m'))
|
| |
|
| |
|
class TestEndToEnd(unittest.TestCase):
    """End-to-end integration tests."""

    def test_data_preparation_flow(self):
        """Build the sample dataset and check that every example is well formed."""
        processor = DocumentProcessor()
        dataset_creator = NERDatasetCreator(processor)

        dataset = dataset_creator.create_sample_dataset()

        self.assertIsInstance(dataset, list)
        self.assertGreater(len(dataset), 0)

        for example in dataset:
            self.assertIn('tokens', example)
            self.assertIn('labels', example)
            self.assertIn('text', example)
            # Tokens and labels are expected to line up one-to-one.
            self.assertEqual(len(example['tokens']), len(example['labels']))

    def test_model_config_flow(self):
        """Create a small config and build the model and trainer from it."""
        config = ModelConfig(batch_size=4, num_epochs=1)

        model, trainer = create_model_and_trainer(config)

        self.assertIsNotNone(model)
        self.assertIsNotNone(trainer)
        self.assertEqual(trainer.config.batch_size, 4)
        self.assertEqual(trainer.config.num_epochs, 1)

    def test_save_and_load_dataset(self):
        """Round-trip the sample dataset through a JSON file on disk."""
        processor = DocumentProcessor()
        dataset_creator = NERDatasetCreator(processor)
        dataset = dataset_creator.create_sample_dataset()

        # TemporaryDirectory removes the file on exit even if json.dump or
        # json.load raises, so a failed write cannot leave a stray temp file.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = Path(tmp_dir) / 'dataset.json'
            with temp_path.open('w', encoding='utf-8') as f:
                json.dump(dataset, f, indent=2)

            with temp_path.open('r', encoding='utf-8') as f:
                loaded_dataset = json.load(f)

        self.assertEqual(len(loaded_dataset), len(dataset))
        self.assertEqual(loaded_dataset[0]['text'], dataset[0]['text'])
|
| |
|
| |
|
def run_tests():
    """Run every test class in this module through a verbose text runner.

    After the run, prints a one-line summary. When there are failures or
    errors, it also prints a recap of each one.

    Returns:
        bool: ``True`` when the whole suite succeeded, ``False`` otherwise.
    """
    print("Running Document Text Extraction Tests")
    print("=" * 50)

    loader = unittest.TestLoader()
    case_classes = (
        TestDocumentProcessor,
        TestNERDatasetCreator,
        TestModelConfig,
        TestModelCreation,
        TestInference,
        TestEndToEnd,
    )
    # Flatten each class's tests into a single suite, preserving order.
    suite = unittest.TestSuite(
        test
        for case in case_classes
        for test in loader.loadTestsFromTestCase(case)
    )

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)

    if not outcome.wasSuccessful():
        print(f"\n{len(outcome.failures)} failures, {len(outcome.errors)} errors")
    else:
        print(f"\nAll tests passed! ({outcome.testsRun} tests)")

    for heading, problems in (("Failures", outcome.failures), ("Errors", outcome.errors)):
        if not problems:
            continue
        print(f"\n{heading}:")
        for test, trace in problems:
            print(f" {test}: {trace}")

    return outcome.wasSuccessful()
|
| |
|
| |
|
if __name__ == "__main__":
    # run_tests() returns True only when the suite succeeded. Turn that into the
    # process exit status so CI and shell callers can detect failing tests.
    raise SystemExit(0 if run_tests() else 1)