# ocr/utils.py
# (scraped page header preserved: author hanz245, commit 7111e1a "set up" —
# converted to comments so the module remains importable)
"""
Utility Functions for CRNN+CTC Civil Registry OCR
Includes CTC decoding, metrics calculation, and helper functions
"""
import torch
import numpy as np
def _editdistance(a, b):
"""Pure-Python Levenshtein distance β€” replaces the editdistance C extension."""
m, n = len(a), len(b)
dp = list(range(n + 1))
for i in range(1, m + 1):
prev, dp[0] = dp[0], i
for j in range(1, n + 1):
prev, dp[j] = dp[j], prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1])
return dp[n]
from typing import List, Dict, Tuple
def decode_ctc_predictions(outputs, idx_to_char, method='greedy'):
    """
    Decode CTC predictions to text.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars]
        idx_to_char: Dictionary mapping indices to characters
        method: 'greedy' or 'beam_search'

    Returns:
        List of decoded strings

    Raises:
        ValueError: if `method` is not a known decoding strategy.
    """
    # Guard-clause dispatch: each known method returns immediately.
    if method == 'greedy':
        return greedy_decode(outputs, idx_to_char)
    if method == 'beam_search':
        return beam_search_decode(outputs, idx_to_char)
    raise ValueError(f"Unknown decoding method: {method}")
def greedy_decode(outputs, idx_to_char):
    """
    Greedy (best-path) CTC decoding — fast but less accurate than beam search.

    Takes the argmax character at every timestep, then applies the CTC
    collapse rule: drop blanks (index 0) and merge consecutive duplicates.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars]
        idx_to_char: Dictionary mapping indices to characters (0 = blank)

    Returns:
        List of decoded strings, one per batch element
    """
    # Best label per timestep, then reshape to [batch, seq_len].
    best_path = torch.argmax(outputs, dim=2).permute(1, 0)

    results = []
    for labels in best_path.tolist():
        pieces = []
        last_label = -1
        for label in labels:
            # Emit only non-blank labels that differ from the previous frame
            # and have a character mapping.
            if label != 0 and label != last_label and label in idx_to_char:
                pieces.append(idx_to_char[label])
            last_label = label
        results.append(''.join(pieces))
    return results
def beam_search_decode(outputs, idx_to_char, beam_width=10):
    """
    CTC prefix beam search decoding — slower than greedy but more accurate.

    Fixes two defects of the previous naive beam search:
    1. It merged consecutive duplicate characters unconditionally, so a
       repeated character separated by a blank (e.g. the "ll" in "hello")
       could never appear in the output.
    2. It combined beams that reached the same string with max() instead of
       summing their probabilities, underestimating likely prefixes.

    Each prefix now tracks two probabilities — ending in blank (p_b) and
    ending in its final non-blank character (p_nb) — per the standard CTC
    prefix beam search algorithm.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars] (raw logits)
        idx_to_char: Dictionary mapping indices to characters (0 = blank)
        beam_width: Number of prefixes kept after each timestep

    Returns:
        List of decoded strings, one per batch element
    """
    probs = torch.nn.functional.softmax(outputs, dim=2)
    probs = probs.permute(1, 0, 2).cpu().numpy()  # [batch, seq_len, num_chars]

    decoded_texts = []
    for sequence_probs in probs:
        # prefix -> (p_blank, p_non_blank)
        beams = {'': (1.0, 0.0)}
        for timestep in sequence_probs:
            new_beams = {}

            def _accumulate(prefix, add_b, add_nb):
                # Sum probability mass for prefixes reached via multiple paths.
                old_b, old_nb = new_beams.get(prefix, (0.0, 0.0))
                new_beams[prefix] = (old_b + add_b, old_nb + add_nb)

            for prefix, (p_b, p_nb) in beams.items():
                total = p_b + p_nb
                # Emitting blank never changes the prefix.
                _accumulate(prefix, total * timestep[0], 0.0)
                for idx in range(1, len(timestep)):
                    if idx not in idx_to_char:
                        continue  # index outside the charset — ignore
                    char = idx_to_char[idx]
                    char_prob = timestep[idx]
                    if prefix and prefix[-1] == char:
                        # Repeat of the last char: without an intervening blank
                        # it collapses into the prefix; after a blank it
                        # genuinely extends it.
                        _accumulate(prefix, 0.0, p_nb * char_prob)
                        _accumulate(prefix + char, 0.0, p_b * char_prob)
                    else:
                        _accumulate(prefix + char, 0.0, total * char_prob)

            # Prune to the beam_width most probable prefixes.
            ranked = sorted(new_beams.items(),
                            key=lambda kv: kv[1][0] + kv[1][1],
                            reverse=True)
            beams = dict(ranked[:beam_width])

        best_sequence = max(beams.items(), key=lambda kv: kv[1][0] + kv[1][1])[0]
        decoded_texts.append(best_sequence)
    return decoded_texts
def calculate_cer(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate Character Error Rate (CER), as a percentage.

    CER = (Substitutions + Deletions + Insertions) / Total Characters

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")
    pairs = zip(predictions, ground_truths)
    total_distance = sum(_editdistance(pred, gt) for pred, gt in pairs)
    total_length = sum(len(gt) for gt in ground_truths)
    # Guard against division by zero when all references are empty.
    if total_length == 0:
        return 0
    return total_distance / total_length * 100
def calculate_wer(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate Word Error Rate (WER), as a percentage.

    WER = (Substitutions + Deletions + Insertions) / Total Words
    Words are produced by whitespace splitting.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")
    total_distance = sum(
        _editdistance(pred.split(), gt.split())
        for pred, gt in zip(predictions, ground_truths)
    )
    total_length = sum(len(gt.split()) for gt in ground_truths)
    # Guard against division by zero when all references are empty.
    if total_length == 0:
        return 0
    return total_distance / total_length * 100
def calculate_accuracy(predictions: List[str], ground_truths: List[str]) -> float:
    """
    Calculate exact-match accuracy (percentage of predictions identical to
    their ground truth).

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")
    if not predictions:
        return 0
    matches = sum(pred == gt for pred, gt in zip(predictions, ground_truths))
    return matches / len(predictions) * 100
class EarlyStopping:
    """
    Stop training once the monitored validation loss plateaus.

    An epoch counts as an improvement only when its loss is at least
    `min_delta` below the best loss seen so far. After `patience`
    non-improving epochs in a row, `early_stop` latches to True.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience      # allowed consecutive non-improving epochs
        self.min_delta = min_delta    # minimum decrease that counts as progress
        self.counter = 0              # current streak of non-improving epochs
        self.best_loss = None         # best validation loss observed so far
        self.early_stop = False       # latched stop flag

    def __call__(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        improved = self.best_loss is None or val_loss <= self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
class AverageMeter:
    """
    Track a metric's latest value together with its running sum, sample
    count, and mean — the usual training-loop bookkeeping helper.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Fold in a new observation; `n` is how many samples `val` averages over."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def calculate_confusion_matrix(predictions: List[str], ground_truths: List[str], char_set: List[str]) -> np.ndarray:
    """
    Calculate a character-level confusion matrix using naive positional
    alignment (shorter string right-padded with spaces).

    Args:
        predictions: List of predicted strings
        ground_truths: List of ground truth strings
        char_set: List of all possible characters

    Returns:
        Confusion matrix [num_chars, num_chars], indexed [ground_truth, predicted];
        characters outside `char_set` are silently skipped.
    """
    index_of = {char: pos for pos, char in enumerate(char_set)}
    matrix = np.zeros((len(char_set), len(char_set)), dtype=np.int64)

    for pred, gt in zip(predictions, ground_truths):
        width = max(len(pred), len(gt))
        # ljust pads with spaces so the two strings compare position by position.
        for p_char, g_char in zip(pred.ljust(width), gt.ljust(width)):
            if p_char in index_of and g_char in index_of:
                matrix[index_of[g_char], index_of[p_char]] += 1
    return matrix
def extract_form_fields(text: str, form_type: str) -> Dict[str, str]:
    """
    Extract specific fields from recognized text based on form type.

    Currently only labels the document type; proper field extraction
    (regex patterns / NER) is left as future work.

    Args:
        text: Recognized text
        form_type: 'form1a', 'form2a', 'form3a', 'form90'

    Returns:
        Dictionary of extracted fields (empty for unknown form types)
    """
    # Civil-registry form codes mapped to their human-readable document types.
    type_labels = {
        'form1a': 'Birth Certificate',
        'form2a': 'Death Certificate',
        'form3a': 'Marriage Certificate',
        'form90': 'Marriage License Application',
    }
    fields: Dict[str, str] = {}
    if form_type in type_labels:
        fields['type'] = type_labels[form_type]
    return fields
def validate_extracted_data(data: Dict[str, str], form_type: str) -> Tuple[bool, List[str]]:
    """
    Validate extracted data for completeness.

    A field fails validation when it is absent or holds a falsy value
    (e.g. empty string). Unknown form types have no requirements and
    therefore always validate.

    Args:
        data: Extracted data dictionary
        form_type: Form type key ('form1a', 'form2a', 'form3a', 'form90')

    Returns:
        (is_valid, list_of_errors)
    """
    # Required fields per form type.
    required_fields = {
        'form1a': ['name', 'date_of_birth', 'place_of_birth'],
        'form2a': ['name', 'date_of_death', 'place_of_death'],
        'form3a': ['husband_name', 'wife_name', 'date_of_marriage'],
        'form90': ['husband_name', 'wife_name', 'date_of_application']
    }
    errors = [
        f"Missing required field: {field}"
        for field in required_fields.get(form_type, [])
        if not data.get(field)
    ]
    # Further checks (date/name format, etc.) can be added here.
    return not errors, errors
def load_checkpoint(checkpoint_path, model, optimizer=None, device='cpu'):
    """
    Restore model (and optionally optimizer) state from a checkpoint file,
    then print a short summary of what was loaded.

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model instance whose weights are overwritten in place
        optimizer: Optimizer instance (optional)
        device: map_location target for the loaded tensors

    Returns:
        (model, optimizer, checkpoint_dict)
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Summarize the restored checkpoint; fall back gracefully when the
    # validation metrics are absent.
    if 'val_cer' in checkpoint:
        metric_line = f" Val CER : {checkpoint['val_cer']:.4f}%"
    elif 'val_loss' in checkpoint:
        metric_line = f" Val Loss : {checkpoint['val_loss']:.4f} (run compare_live_cer.py for true CER)"
    else:
        metric_line = f" Val CER : N/A (run compare_live_cer.py for true CER)"
    print(f"βœ“ Loaded checkpoint from {checkpoint_path}")
    print(f" Epoch: {checkpoint.get('epoch', 'N/A')}")
    print(metric_line)
    return model, optimizer, checkpoint
def save_predictions_to_file(predictions: List[str], ground_truths: List[str], output_file: str):
    """
    Write ground-truth/prediction pairs to a tab-separated report file for
    later error analysis. Each row is marked βœ“ (exact match) or βœ—.
    """
    header = "Ground Truth\tPrediction\tMatch\n" + "=" * 80 + "\n"
    rows = [
        f"{gt}\t{pred}\t{'βœ“' if gt == pred else 'βœ—'}\n"
        for gt, pred in zip(ground_truths, predictions)
    ]
    with open(output_file, 'w', encoding='utf-8') as report:
        report.write(header)
        report.writelines(rows)
    print(f"βœ“ Predictions saved to {output_file}")
if __name__ == "__main__":
    # Smoke-test the metric helpers and EarlyStopping with toy data.
    banner = "=" * 60
    print(banner)
    print("Testing Utility Functions")
    print(banner)

    # Metric checks on a small prediction/ground-truth set.
    predictions = ["Hello World", "Test", "Sample Text"]
    ground_truths = ["Hello World", "Tset", "Sample Txt"]
    print(f"\nMetrics:")
    print(f" CER: {calculate_cer(predictions, ground_truths):.2f}%")
    print(f" WER: {calculate_wer(predictions, ground_truths):.2f}%")
    print(f" Accuracy: {calculate_accuracy(predictions, ground_truths):.2f}%")

    # Early stopping should trigger once the loss plateaus at 0.84.
    print("\nTesting Early Stopping:")
    early_stopping = EarlyStopping(patience=3, min_delta=0.001)
    for epoch, loss in enumerate([1.0, 0.9, 0.85, 0.84, 0.84, 0.84, 0.84], 1):
        should_stop = early_stopping(loss)
        print(f" Epoch {epoch}: Loss = {loss:.2f}, Stop = {should_stop}")
        if should_stop:
            break