| """ |
| Utility Functions for CRNN+CTC Civil Registry OCR |
| Includes CTC decoding, metrics calculation, and helper functions |
| """ |
|
|
| import torch |
| import numpy as np |
| def _editdistance(a, b): |
| """Pure-Python Levenshtein distance β replaces the editdistance C extension.""" |
| m, n = len(a), len(b) |
| dp = list(range(n + 1)) |
| for i in range(1, m + 1): |
| prev, dp[0] = dp[0], i |
| for j in range(1, n + 1): |
| prev, dp[j] = dp[j], prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1]) |
| return dp[n] |
| from typing import List, Dict, Tuple |
|
|
|
|
def decode_ctc_predictions(outputs, idx_to_char, method='greedy'):
    """
    Decode CTC predictions to text.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars]
        idx_to_char: Dictionary mapping indices to characters
        method: 'greedy' (fast) or 'beam_search' (more accurate)

    Returns:
        List of decoded strings, one per batch element

    Raises:
        ValueError: if `method` is not a recognized decoding strategy
    """
    if method == 'greedy':
        return greedy_decode(outputs, idx_to_char)
    if method == 'beam_search':
        return beam_search_decode(outputs, idx_to_char)
    raise ValueError(f"Unknown decoding method: {method}")
|
|
|
|
def greedy_decode(outputs, idx_to_char):
    """
    Greedy CTC decoding - fast but less accurate.

    Takes the argmax character at every timestep, then applies the standard
    CTC collapse: consecutive repeats are merged and blanks (index 0) are
    removed.
    """
    # [seq_len, batch, num_chars] -> [batch, seq_len] of best indices.
    best_path = torch.argmax(outputs, dim=2).permute(1, 0)

    texts = []
    for sequence in best_path:
        pieces = []
        previous = -1   # sentinel: nothing emitted yet
        for step in sequence.tolist():
            # Emit only non-blank indices that differ from the previous step.
            if step != 0 and step != previous and step in idx_to_char:
                pieces.append(idx_to_char[step])
            previous = step
        texts.append(''.join(pieces))

    return texts
|
|
|
|
def beam_search_decode(outputs, idx_to_char, beam_width=10):
    """
    CTC prefix beam search decoding - slower than greedy but more accurate.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars] (raw logits);
            index 0 is the CTC blank.
        idx_to_char: Dictionary mapping indices to characters
        beam_width: Number of prefixes kept after each timestep

    Returns:
        List of decoded strings, one per batch element

    FIXED: the previous implementation kept a single probability per
    sequence, merged duplicate prefixes with max() instead of summing the
    probabilities of distinct paths, and always collapsed a repeated
    character even when a blank separated the two emissions — so the path
    A-blank-A wrongly decoded to "A" instead of "AA". This version is a
    standard CTC prefix beam search (Graves 2006 / Hannun 2014): each
    prefix tracks separate probabilities for paths ending in blank vs.
    ending in its last character, and merged paths are summed.
    """
    probs = torch.nn.functional.softmax(outputs, dim=2)
    probs = probs.permute(1, 0, 2).cpu().numpy()   # -> [batch, seq_len, num_chars]

    decoded_texts = []

    for output in probs:
        # prefix -> (p_blank, p_nonblank): total probability of all paths
        # collapsing to this prefix that end in blank / in its last char.
        beams = {'': (1.0, 0.0)}

        for timestep in output:
            new_beams = {}

            def accumulate(prefix, p_b, p_nb):
                # Sum (not max) probabilities of paths reaching the same prefix.
                old_b, old_nb = new_beams.get(prefix, (0.0, 0.0))
                new_beams[prefix] = (old_b + p_b, old_nb + p_nb)

            for prefix, (p_b, p_nb) in beams.items():
                # Emit blank: prefix unchanged, path now ends in blank.
                accumulate(prefix, (p_b + p_nb) * timestep[0], 0.0)

                for idx, char_prob in enumerate(timestep):
                    if idx == 0 or idx not in idx_to_char:
                        continue
                    char = idx_to_char[idx]
                    if prefix and prefix[-1] == char:
                        # Repeat of the last char: with no intervening blank
                        # it collapses into the existing run ...
                        accumulate(prefix, 0.0, p_nb * char_prob)
                        # ... but after a blank it starts a new character.
                        accumulate(prefix + char, 0.0, p_b * char_prob)
                    else:
                        accumulate(prefix + char, 0.0, (p_b + p_nb) * char_prob)

            # Prune to the most probable prefixes.
            ranked = sorted(new_beams.items(),
                            key=lambda kv: kv[1][0] + kv[1][1],
                            reverse=True)[:beam_width]
            beams = dict(ranked)

        best = max(beams.items(), key=lambda kv: kv[1][0] + kv[1][1])[0]
        decoded_texts.append(best)

    return decoded_texts
|
|
|
|
def calculate_cer(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate Character Error Rate (CER) as a percentage.

    CER = (Substitutions + Deletions + Insertions) / Total Characters * 100

    Raises:
        ValueError: if the two lists have different lengths
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    pairs = list(zip(predictions, ground_truths))
    total_errors = sum(_editdistance(pred, truth) for pred, truth in pairs)
    # Normalize by ground-truth length (the conventional CER denominator).
    total_chars = sum(len(truth) for _, truth in pairs)

    return (total_errors / total_chars * 100) if total_chars > 0 else 0
|
|
|
|
def calculate_wer(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate Word Error Rate (WER) as a percentage.

    WER = (Substitutions + Deletions + Insertions) / Total Words * 100
    Words are whitespace-delimited tokens.

    Raises:
        ValueError: if the two lists have different lengths
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    total_errors = 0
    total_words = 0
    for pred, truth in zip(predictions, ground_truths):
        truth_words = truth.split()
        # Edit distance over word lists, not characters.
        total_errors += _editdistance(pred.split(), truth_words)
        total_words += len(truth_words)

    return (total_errors / total_words * 100) if total_words > 0 else 0
|
|
|
|
def calculate_accuracy(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate exact (full-string) match accuracy as a percentage.

    Raises:
        ValueError: if the two lists have different lengths
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    if not predictions:
        return 0

    matches = sum(pred == truth for pred, truth in zip(predictions, ground_truths))
    return matches / len(predictions) * 100
|
|
|
|
class EarlyStopping:
    """
    Stop training once validation loss has not improved for `patience` epochs.

    A loss counts as an improvement only if it beats the best seen so far
    by more than `min_delta`. The instance is called once per epoch with
    the validation loss and returns True when training should stop.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # epochs since the last improvement
        self.best_loss = None     # lowest validation loss observed
        self.early_stop = False

    def __call__(self, val_loss):
        improved = (self.best_loss is None
                    or val_loss <= self.best_loss - self.min_delta)
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop
|
|
|
|
class AverageMeter:
    """
    Tracks the most recent value plus a running, count-weighted average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0     # most recent value passed to update()
        self.sum = 0     # weighted sum of all values
        self.count = 0   # total weight (e.g. number of samples)
        self.avg = 0     # sum / count, kept current by update()

    def update(self, val, n=1):
        """Record `val` (typically a batch mean) weighted by `n` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
def calculate_confusion_matrix(predictions: List[str], ground_truths: List[str], char_set: List[str]) -> np.ndarray:
    """
    Calculate a character-level confusion matrix by positional alignment.

    Args:
        predictions: List of predicted strings
        ground_truths: List of ground truth strings
        char_set: List of all possible characters

    Returns:
        Confusion matrix [num_chars, num_chars]; rows are ground-truth
        characters, columns are predicted characters. Characters outside
        `char_set` (including the space used for padding, if absent) are
        skipped.
    """
    lookup = {char: pos for pos, char in enumerate(char_set)}
    size = len(char_set)
    matrix = np.zeros((size, size), dtype=np.int64)

    for pred, truth in zip(predictions, ground_truths):
        # Pad the shorter string with spaces so both align position-by-position.
        width = max(len(pred), len(truth))
        for p_char, t_char in zip(pred.ljust(width), truth.ljust(width)):
            if p_char in lookup and t_char in lookup:
                matrix[lookup[t_char], lookup[p_char]] += 1

    return matrix
|
|
|
|
def extract_form_fields(text: str, form_type: str) -> Dict[str, str]:
    """
    Extract specific fields from recognized text based on form type.

    Args:
        text: Recognized text (currently unused; reserved for field parsing)
        form_type: 'form1a', 'form2a', 'form3a', 'form90'

    Returns:
        Dictionary of extracted fields; only the 'type' label is populated
        for known form types, and an empty dict is returned otherwise.
    """
    labels = {
        'form1a': 'Birth Certificate',
        'form2a': 'Death Certificate',
        'form3a': 'Marriage Certificate',
        'form90': 'Marriage License Application',
    }

    fields: Dict[str, str] = {}
    if form_type in labels:
        fields['type'] = labels[form_type]

    return fields
|
|
|
|
def validate_extracted_data(data: Dict[str, str], form_type: str) -> Tuple[bool, List[str]]:
    """
    Validate extracted data for completeness.

    Args:
        data: Extracted data dictionary
        form_type: Form type ('form1a', 'form2a', 'form3a', 'form90')

    Returns:
        (is_valid, list_of_errors); unknown form types have no required
        fields and therefore validate as True.
    """
    required_fields = {
        'form1a': ['name', 'date_of_birth', 'place_of_birth'],
        'form2a': ['name', 'date_of_death', 'place_of_death'],
        'form3a': ['husband_name', 'wife_name', 'date_of_marriage'],
        'form90': ['husband_name', 'wife_name', 'date_of_application']
    }

    # A field is missing when absent or empty/falsy.
    errors = [
        f"Missing required field: {field}"
        for field in required_fields.get(form_type, [])
        if not data.get(field)
    ]

    return not errors, errors
|
|
|
|
def load_checkpoint(checkpoint_path, model, optimizer=None, device='cpu'):
    """
    Load model checkpoint and restore model (and optionally optimizer) state.

    Args:
        checkpoint_path: Path to checkpoint file (as saved by torch.save;
            must contain at least 'model_state_dict')
        model: Model instance whose state_dict will be overwritten in place
        optimizer: Optimizer instance (optional); restored only if the
            checkpoint also stored 'optimizer_state_dict'
        device: Device to map loaded tensors to (default 'cpu')

    Returns:
        (model, optimizer, checkpoint_dict) — model/optimizer are the same
        objects passed in; checkpoint_dict is the full loaded dictionary.

    Raises:
        KeyError: if the checkpoint lacks 'model_state_dict'.
        Whatever torch.load raises for a missing/corrupt file.
    """
    # map_location keeps GPU-saved tensors loadable on CPU-only machines.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Optimizer state is optional — older checkpoints may not carry it.
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f"β Loaded checkpoint from {checkpoint_path}")
    print(f" Epoch: {checkpoint.get('epoch', 'N/A')}")
    # Prefer a stored CER; fall back to loss; otherwise point at the
    # external script that recomputes CER from scratch.
    if 'val_cer' in checkpoint:
        print(f" Val CER : {checkpoint['val_cer']:.4f}%")
    elif 'val_loss' in checkpoint:
        print(f" Val Loss : {checkpoint['val_loss']:.4f} (run compare_live_cer.py for true CER)")
    else:
        print(f" Val CER : N/A (run compare_live_cer.py for true CER)")
    
    return model, optimizer, checkpoint
|
|
|
|
def save_predictions_to_file(predictions: List[str], ground_truths: List[str], output_file: str):
    """
    Save predictions and ground truths to a tab-separated file for analysis.

    Each row holds the ground truth, the prediction, and a match marker.
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Ground Truth\tPrediction\tMatch\n")
        f.write("=" * 80 + "\n")

        for truth, pred in zip(ground_truths, predictions):
            marker = "β" if truth == pred else "β"
            f.write(f"{truth}\t{pred}\t{marker}\n")

    print(f"β Predictions saved to {output_file}")
|
|
|
|
if __name__ == "__main__":
    # Smoke-test the metric helpers and early stopping with toy data.
    print("=" * 60)
    print("Testing Utility Functions")
    print("=" * 60)

    demo_preds = ["Hello World", "Test", "Sample Text"]
    demo_truths = ["Hello World", "Tset", "Sample Txt"]

    demo_cer = calculate_cer(demo_preds, demo_truths)
    demo_wer = calculate_wer(demo_preds, demo_truths)
    demo_acc = calculate_accuracy(demo_preds, demo_truths)

    print(f"\nMetrics:")
    print(f" CER: {demo_cer:.2f}%")
    print(f" WER: {demo_wer:.2f}%")
    print(f" Accuracy: {demo_acc:.2f}%")

    # Early stopping should trigger once the loss plateaus for `patience` epochs.
    print("\nTesting Early Stopping:")
    stopper = EarlyStopping(patience=3, min_delta=0.001)

    for epoch, loss in enumerate([1.0, 0.9, 0.85, 0.84, 0.84, 0.84, 0.84], 1):
        should_stop = stopper(loss)
        print(f" Epoch {epoch}: Loss = {loss:.2f}, Stop = {should_stop}")
        if should_stop:
            break