# NOTE: the following three comment lines are residue from the hosting web page, not code.
# Vaibuzzz's picture
# Upload folder using huggingface_hub
# 10ff0db verified
"""
Post-Extraction Validation Engine.
Performs programmatic validation checks on extracted financial data
that complement the model's anomaly detection:
1. Arithmetic Consistency — line items sum to subtotal, subtotal + tax = total
2. Required Field Completeness — checks for missing critical fields per doc type
3. Date Format Validation — ensures dates are valid and reasonable
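4. Cross-Field Checks — due/delivery date not before the document date, non-negative total
5. Business Logic Checks — very large or perfectly round totals, negative quantities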
Usage:
from src.validator import validate_extraction
extra_flags = validate_extraction(json_data)
# Returns list of additional anomaly flag dicts
"""
import re
from datetime import datetime
from typing import List, Optional
def validate_extraction(data: dict) -> List[dict]:
    """
    Run every programmatic validation pass over one extracted document.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        The flags produced by all passes, concatenated in the order the
        passes run: arithmetic, required fields, date formats, cross-field,
        business logic.
    """
    passes = (
        _check_arithmetic,
        _check_required_fields,
        _check_date_formats,
        _check_cross_field,
        _check_business_logic,
    )
    collected: List[dict] = []
    for run_pass in passes:
        collected += run_pass(data)
    return collected
def _check_arithmetic(data: dict) -> List[dict]:
    """Verify that the amounts in the document are arithmetically consistent.

    Four checks run, each only when every value it needs is present:

    1. Per line item, ``amount`` equals ``quantity * unit_price``
       (tolerance 0.02).
    2. The line-item amounts sum to ``type_specific.subtotal``
       (tolerance 0.05).
    3. ``type_specific.subtotal + type_specific.tax_amount`` equals
       ``common.total_amount`` (tolerance 0.05).
    4. For ``bank_statement`` documents, ``opening_balance`` plus the
       line-item amounts equals ``closing_balance`` (tolerance 0.10).

    A value that cannot be converted to ``float`` makes the affected check
    skip instead of raising. Sections that are missing, ``null`` or not a
    dict are treated as empty, and line items that are not dicts are ignored.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        A list of ``arithmetic_error`` flag dicts (possibly empty).
    """
    # ``data.get(key, {})`` still yields None when the key is present with a
    # null value, so normalise each section explicitly.
    common = data.get("common")
    if not isinstance(common, dict):
        common = {}
    type_specific = data.get("type_specific")
    if not isinstance(type_specific, dict):
        type_specific = {}
    raw_items = data.get("line_items")
    if not isinstance(raw_items, (list, tuple)):
        raw_items = []

    # Keep the original positions so flagged paths still index into the
    # caller's ``line_items`` list even when malformed entries are skipped.
    indexed_items = [
        (i, item) for i, item in enumerate(raw_items) if isinstance(item, dict)
    ]

    # OverflowError: float() on an integer too large for a double.
    conversion_errors = (ValueError, TypeError, OverflowError)
    flags: List[dict] = []

    def add_flag(field: str, description: str) -> None:
        flags.append({
            "category": "arithmetic_error",
            "field": field,
            "severity": "high",
            "description": description,
            "source": "validator",
        })

    def amounts_total() -> float:
        # A missing, null or empty amount contributes 0 to the sum.
        return sum(float(item.get("amount", 0) or 0) for _, item in indexed_items)

    # Check 1: Line item amounts = quantity × unit_price
    for i, item in indexed_items:
        qty = item.get("quantity")
        price = item.get("unit_price")
        amount = item.get("amount")
        if qty is None or price is None or amount is None:
            continue
        try:
            expected = round(float(qty) * float(price), 2)
            actual = round(float(amount), 2)
        except conversion_errors:
            continue
        if abs(expected - actual) > 0.02:  # 2 cent tolerance
            add_flag(
                f"line_items[{i}].amount",
                f"Line item '{item.get('description', '?')}': "
                f"amount {actual} ≠ quantity ({qty}) × unit_price ({price}) = {expected}",
            )

    # Check 2: Line items sum to subtotal
    subtotal = type_specific.get("subtotal")
    if indexed_items and subtotal is not None:
        try:
            items_sum = round(amounts_total(), 2)
            subtotal_value = round(float(subtotal), 2)
        except conversion_errors:
            pass
        else:
            if abs(items_sum - subtotal_value) > 0.05:
                add_flag(
                    "type_specific.subtotal",
                    f"Sum of line items ({items_sum}) ≠ subtotal ({subtotal_value}). "
                    f"Discrepancy: {abs(items_sum - subtotal_value):.2f}",
                )

    # Check 3: Subtotal + tax = total
    tax = type_specific.get("tax_amount")
    total = common.get("total_amount")
    if subtotal is not None and tax is not None and total is not None:
        try:
            expected_total = round(float(subtotal) + float(tax), 2)
            actual_total = round(float(total), 2)
        except conversion_errors:
            pass
        else:
            if abs(expected_total - actual_total) > 0.05:
                add_flag(
                    "common.total_amount",
                    f"Total ({actual_total}) ≠ subtotal ({subtotal}) + tax ({tax}) = {expected_total}. "
                    f"Discrepancy: {abs(expected_total - actual_total):.2f}",
                )

    # Check 4: Bank statement — opening + transactions = closing
    if common.get("document_type") == "bank_statement":
        opening = type_specific.get("opening_balance")
        closing = type_specific.get("closing_balance")
        if opening is not None and closing is not None and indexed_items:
            try:
                txn_sum = amounts_total()
                expected_closing = round(float(opening) + txn_sum, 2)
                actual_closing = round(float(closing), 2)
            except conversion_errors:
                pass
            else:
                if abs(expected_closing - actual_closing) > 0.10:
                    add_flag(
                        "type_specific.closing_balance",
                        f"Closing balance ({actual_closing}) ≠ opening ({opening}) + "
                        f"transactions ({txn_sum:.2f}) = {expected_closing}",
                    )

    return flags
def _check_required_fields(data: dict) -> List[dict]:
    """Check for missing critical fields based on document type.

    Three groups of fields are examined:

    * ``common.date``, ``common.total_amount`` and ``common.issuer`` are
      required for every document (severity ``medium``).
    * A per-type list of ``type_specific`` fields, selected by
      ``common.document_type`` (severity ``low``). Unknown or non-string
      document types have no type-specific requirements.
    * When ``common.issuer`` is a dict it must carry a non-empty ``name``
      (severity ``medium``).

    A field counts as missing only when its value is ``None`` or absent.
    Sections that are missing, ``null`` or not a dict are treated as empty.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        A list of ``missing_field`` flag dicts (possibly empty).
    """
    # ``data.get(key, {})`` still yields None when the key is present with a
    # null value, so normalise each section explicitly.
    common = data.get("common")
    if not isinstance(common, dict):
        common = {}
    type_specific = data.get("type_specific")
    if not isinstance(type_specific, dict):
        type_specific = {}

    flags: List[dict] = []

    def add_flag(field: str, severity: str, description: str) -> None:
        flags.append({
            "category": "missing_field",
            "field": field,
            "severity": severity,
            "description": description,
            "source": "validator",
        })

    # Universal required fields
    for key in ("date", "total_amount", "issuer"):
        if common.get(key) is None:
            field_path = f"common.{key}"
            add_flag(
                field_path,
                "medium",
                f"Required field '{field_path}' is missing.",
            )

    # Type-specific required fields
    required_by_type = {
        "invoice": ["invoice_number", "due_date", "subtotal"],
        "purchase_order": ["po_number", "delivery_date"],
        "receipt": ["receipt_number"],
        "bank_statement": ["account_number", "opening_balance", "closing_balance"],
    }
    doc_type = common.get("document_type", "")
    # Only a string can be a key of the table; anything else (including an
    # unhashable list/dict from malformed model output) has no requirements.
    expected = required_by_type.get(doc_type, []) if isinstance(doc_type, str) else []
    for field_name in expected:
        if type_specific.get(field_name) is None:
            add_flag(
                f"type_specific.{field_name}",
                "low",
                f"Expected field '{field_name}' for {doc_type} is missing.",
            )

    # Check issuer has at least a name
    issuer = common.get("issuer")
    if isinstance(issuer, dict) and not issuer.get("name"):
        add_flag(
            "common.issuer.name",
            "medium",
            "Issuer entity is present but name is missing.",
        )

    return flags
def _check_date_formats(data: dict, now: Optional[datetime] = None) -> List[dict]:
    """Validate that date fields are well-formed and in a reasonable range.

    The fields ``common.date``, ``type_specific.due_date`` and
    ``type_specific.delivery_date`` are examined. Values that are ``None`` or
    not strings are skipped. For each remaining value, at most one flag is
    produced:

    * not exactly ``YYYY-MM-DD`` (severity ``medium``);
    * matches the pattern but is not a real calendar date (``medium``);
    * year earlier than 2000 (``low``);
    * later than two years after ``now`` (``medium``).

    Sections that are missing, ``null`` or not a dict are treated as empty.

    Args:
        data: Parsed JSON dict from model output.
        now: Reference "current" time used for the future-date limit. Must be
            a naive ``datetime`` because it is compared with naive parsed
            dates. Defaults to ``datetime.now()``.

    Returns:
        A list of ``format_anomaly`` flag dicts (possibly empty).
    """
    common = data.get("common")
    if not isinstance(common, dict):
        common = {}
    type_specific = data.get("type_specific")
    if not isinstance(type_specific, dict):
        type_specific = {}

    # The cutoff does not depend on the field, so compute it once.
    reference = datetime.now() if now is None else now
    horizon_year = min(reference.year + 2, datetime.max.year)
    try:
        latest_allowed = reference.replace(year=horizon_year)
    except ValueError:
        # ``reference`` is Feb 29 and the target year is not a leap year.
        # Fall back to Feb 28 so that a leap-day clock does not make every
        # valid date look like an invalid calendar date.
        latest_allowed = reference.replace(year=horizon_year, day=28)

    date_fields = {
        "common.date": common.get("date"),
        "type_specific.due_date": type_specific.get("due_date"),
        "type_specific.delivery_date": type_specific.get("delivery_date"),
    }
    date_pattern = r'^\d{4}-\d{2}-\d{2}$'
    flags: List[dict] = []

    def add_flag(field: str, severity: str, description: str) -> None:
        flags.append({
            "category": "format_anomaly",
            "field": field,
            "severity": severity,
            "description": description,
            "source": "validator",
        })

    for field_path, date_str in date_fields.items():
        if not isinstance(date_str, str):
            continue

        # Check YYYY-MM-DD format. fullmatch is used because ``$`` with
        # re.match would also accept a trailing newline.
        if not re.fullmatch(date_pattern, date_str):
            add_flag(
                field_path,
                "medium",
                f"Date '{date_str}' is not in standard YYYY-MM-DD format.",
            )
            continue

        # Check if date is actually valid. Only the parse sits inside the
        # try so unrelated ValueErrors cannot be misreported.
        try:
            parsed = datetime.strptime(date_str, "%Y-%m-%d")
        except ValueError:
            add_flag(
                field_path,
                "medium",
                f"Date '{date_str}' is not a valid calendar date.",
            )
            continue

        # Check reasonable range (not before year 2000 or more than 2 years in future)
        if parsed.year < 2000:
            add_flag(
                field_path,
                "low",
                f"Date '{date_str}' is before year 2000, which is unusual.",
            )
        elif parsed > latest_allowed:
            add_flag(
                field_path,
                "medium",
                f"Date '{date_str}' is more than 2 years in the future.",
            )

    return flags
def _check_cross_field(data: dict) -> List[dict]:
    """Check for inconsistencies between related fields.

    Three checks run:

    * ``type_specific.due_date`` earlier than ``common.date``
      (severity ``high``);
    * ``type_specific.delivery_date`` earlier than ``common.date``
      (severity ``medium``);
    * ``common.total_amount`` below zero (severity ``high``).

    A date comparison is skipped when either side is missing, is not a
    string, or does not parse as ``%Y-%m-%d``. The total check is skipped
    when the value cannot be converted to ``float``. Sections that are
    missing, ``null`` or not a dict are treated as empty.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        A list of ``cross_field`` flag dicts (possibly empty).
    """
    common = data.get("common")
    if not isinstance(common, dict):
        common = {}
    type_specific = data.get("type_specific")
    if not isinstance(type_specific, dict):
        type_specific = {}

    flags: List[dict] = []

    def add_flag(field: str, severity: str, description: str) -> None:
        flags.append({
            "category": "cross_field",
            "field": field,
            "severity": severity,
            "description": description,
            "source": "validator",
        })

    def parse_day(value) -> Optional[datetime]:
        # strptime raises TypeError for non-strings (e.g. an int date), so
        # reject those up front; unparseable strings simply yield None.
        if not isinstance(value, str):
            return None
        try:
            return datetime.strptime(value, "%Y-%m-%d")
        except ValueError:
            return None

    doc_date = common.get("date")
    doc_day = parse_day(doc_date)

    # Check: due_date should be after invoice date
    due_date = type_specific.get("due_date")
    due_day = parse_day(due_date)
    if doc_day is not None and due_day is not None and due_day < doc_day:
        add_flag(
            "type_specific.due_date",
            "high",
            f"Due date ({due_date}) is before document date ({doc_date}).",
        )

    # Check: delivery date should be after PO date
    delivery_date = type_specific.get("delivery_date")
    delivery_day = parse_day(delivery_date)
    if doc_day is not None and delivery_day is not None and delivery_day < doc_day:
        add_flag(
            "type_specific.delivery_date",
            "medium",
            f"Delivery date ({delivery_date}) is before PO date ({doc_date}).",
        )

    # Check: negative total amount
    total = common.get("total_amount")
    if total is not None:
        try:
            is_negative = float(total) < 0
        except (ValueError, TypeError, OverflowError):
            is_negative = False
        if is_negative:
            add_flag(
                "common.total_amount",
                "high",
                f"Total amount is negative ({total}), which is unusual.",
            )

    return flags
def _check_business_logic(data: dict) -> List[dict]:
    """Check for business logic red flags.

    Three checks run:

    * ``common.total_amount`` greater than 1,000,000 (``business_logic``,
      severity ``high``). An infinite total also trips this check.
    * ``common.total_amount`` of at least 10,000 that is a whole multiple of
      1,000 (``business_logic``, severity ``medium``).
    * Any line item with a negative ``quantity`` (``format_anomaly``,
      severity ``medium``).

    Values that cannot be converted to ``float`` are skipped. A ``common``
    section that is missing, ``null`` or not a dict is treated as empty, and
    line items that are not dicts are ignored.

    Args:
        data: Parsed JSON dict from model output.

    Returns:
        A list of flag dicts (possibly empty).
    """
    common = data.get("common")
    if not isinstance(common, dict):
        common = {}
    line_items = data.get("line_items")
    if not isinstance(line_items, (list, tuple)):
        line_items = []

    # OverflowError: float() on an integer too large for a double.
    conversion_errors = (ValueError, TypeError, OverflowError)
    flags: List[dict] = []

    total = common.get("total_amount")
    total_val = None
    if total is not None:
        try:
            total_val = float(total)
        except conversion_errors:
            total_val = None

    if total_val is not None:
        # Extremely large amounts
        if total_val > 1_000_000:
            flags.append({
                "category": "business_logic",
                "field": "common.total_amount",
                "severity": "high",
                "description": f"Total amount ${total_val:,.2f} exceeds $1M — requires review.",
                "source": "validator",
            })

        # Perfectly round large amounts (potential fraud indicator).
        # float.is_integer() is False for inf/nan, whereas int(inf) raises
        # OverflowError.
        if total_val >= 10_000 and total_val.is_integer() and total_val % 1000 == 0:
            flags.append({
                "category": "business_logic",
                "field": "common.total_amount",
                "severity": "medium",
                "description": (
                    f"Total amount ${total_val:,.2f} is a perfectly round number — "
                    f"potential fraud indicator."
                ),
                "source": "validator",
            })

    # Check for negative quantities in line items
    for i, item in enumerate(line_items):
        if not isinstance(item, dict):
            continue
        qty = item.get("quantity")
        if qty is None:
            continue
        try:
            negative = float(qty) < 0
        except conversion_errors:
            continue
        if negative:
            flags.append({
                "category": "format_anomaly",
                "field": f"line_items[{i}].quantity",
                "severity": "medium",
                "description": (
                    f"Line item '{item.get('description', '?')}' has negative "
                    f"quantity ({qty})."
                ),
                "source": "validator",
            })

    return flags
def merge_flags(model_flags: list, validator_flags: list) -> list:
    """
    Merge model-detected and validator-detected flags, removing duplicates.

    Deduplication is based on ``(category, field)`` pairs; a missing key is
    treated as the empty string. Model flags are processed first, so when
    both sources report the same pair the model's flag is the one kept.
    Duplicates within a single source are collapsed as well.

    Args:
        model_flags: Flags from the model output. ``None`` is treated as an
            empty list.
        validator_flags: Flags from programmatic validation. ``None`` is
            treated as an empty list.

    Returns:
        Combined list of unique flags, model flags first, each group in its
        original order.
    """
    seen = set()
    merged = []
    # Model flags take priority, so they are iterated first.
    for source in (model_flags, validator_flags):
        # ``or ()`` lets a missing flag list (None) behave like an empty one.
        for flag in source or ():
            key = (flag.get("category", ""), flag.get("field", ""))
            if key in seen:
                continue
            seen.add(key)
            merged.append(flag)
    return merged