Spaces:
Sleeping
Sleeping
| """ | |
| Pydantic models for financial document extraction. | |
| Option C: Common core + type-specific extensions. | |
| """ | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, List, Literal | |
| from enum import Enum | |
class AnomalyCategory(str, Enum):
    """Closed set of anomaly kinds that can be attached to an extraction.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    # Member name -> string value stored on the member.
    ARITHMETIC = "arithmetic_error"
    MISSING_FIELD = "missing_field"
    FORMAT = "format_anomaly"
    BUSINESS_LOGIC = "business_logic"
    CROSS_FIELD = "cross_field"
class Severity(str, Enum):
    """How serious a detected anomaly is.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    # Declared in ascending order: low, medium, high.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
class Party(BaseModel):
    """An entity named on a document, such as a vendor or a buyer.

    Both fields are optional and default to ``None``.
    """

    name: Optional[str] = None  # entity name, when available
    address: Optional[str] = None  # address as a single string, when available
class CommonFields(BaseModel):
    """Core fields that every supported financial document type carries.

    ``document_type`` is the only required field; it must be one of the four
    listed literals. Every other field is optional.
    """

    # Required discriminator naming the kind of document.
    document_type: Literal["invoice", "purchase_order", "receipt", "bank_statement"]
    # Held as a plain string; no date format is enforced in this section.
    date: Optional[str] = None
    issuer: Optional[Party] = None  # party on the issuing side
    recipient: Optional[Party] = None  # party on the receiving side
    total_amount: Optional[float] = None
    # Defaults to "USD" when omitted; None is also accepted.
    currency: Optional[str] = "USD"
class LineItem(BaseModel):
    """One row of a financial document's itemised section.

    Only ``description`` is required; the numeric fields default to ``None``.
    """

    description: str  # required free-text description of the item
    quantity: Optional[float] = None
    unit_price: Optional[float] = None
    # Stored as given; nothing here checks it against quantity * unit_price.
    amount: Optional[float] = None
| # === Type-Specific Extensions === | |
class InvoiceFields(BaseModel):
    """Extension fields that apply only to invoices.

    Every field is optional and defaults to ``None``.
    """

    invoice_number: Optional[str] = None
    due_date: Optional[str] = None  # plain string; no format enforced here
    payment_terms: Optional[str] = None
    tax_amount: Optional[float] = None
    subtotal: Optional[float] = None
class PurchaseOrderFields(BaseModel):
    """Extension fields that apply only to purchase orders.

    Every field is optional and defaults to ``None``.
    """

    po_number: Optional[str] = None
    delivery_date: Optional[str] = None  # plain string; no format enforced here
    shipping_address: Optional[str] = None
    referenced_invoice: Optional[str] = None  # presumably an invoice number; confirm with callers
class ReceiptFields(BaseModel):
    """Extension fields that apply only to receipts.

    Every field is optional and defaults to ``None``.
    """

    receipt_number: Optional[str] = None
    payment_method: Optional[str] = None  # free-form string; no fixed set of values
    store_location: Optional[str] = None
    cashier: Optional[str] = None
class BankStatementFields(BaseModel):
    """Extension fields that apply only to bank statements.

    Every field is optional and defaults to ``None``.
    """

    account_number: Optional[str] = None
    statement_period: Optional[str] = None  # single string; no range format enforced here
    opening_balance: Optional[float] = None
    closing_balance: Optional[float] = None
class AnomalyFlag(BaseModel):
    """Record of one anomaly found in a document.

    All four fields are required.
    """

    category: AnomalyCategory  # which kind of anomaly this is
    field: str  # name of the field the anomaly is reported against
    severity: Severity  # low / medium / high
    description: str  # human-readable explanation of the anomaly
class DocumentExtraction(BaseModel):
    """Root object returned for one extracted document.

    Follows schema "Option C": a shared ``common`` core, plus a free-form
    ``type_specific`` mapping for fields that depend on the document type.
    Only ``common`` is required; every other field has a default.
    """

    common: CommonFields
    # Defaults to an empty list; None is also accepted.
    line_items: Optional[List[LineItem]] = []
    # Untyped dict so a single field can carry any document type's extras.
    type_specific: dict = {}
    # Anomalies found in the document; empty when none were reported.
    flags: List[AnomalyFlag] = []
    # Must lie within [0.0, 1.0]; falls back to 0.95 when not supplied.
    confidence_score: float = Field(default=0.95, ge=0.0, le=1.0)

    def has_anomalies(self) -> bool:
        """Return True when at least one anomaly flag is present."""
        flag_count = len(self.flags)
        return flag_count > 0

    def high_severity_flags(self) -> List[AnomalyFlag]:
        """Return the flags whose severity is HIGH, in their original order."""
        severe = [flag for flag in self.flags if flag.severity == Severity.HIGH]
        return severe