Docgenie-API / api /schemas.py
Ahadhassan-2003
deploy: update HF Space
dc4e6da
"""
Pydantic schemas for API request/response models.
"""
from typing import List, Optional
from pydantic import BaseModel, HttpUrl, Field, field_validator
class PromptParameters(BaseModel):
    """Parameters for customizing the document generation prompt."""

    # Every field has a default, so an empty payload yields a fully
    # populated parameter set.

    # ---- Prompt content ----
    language: str = Field(
        "English",
        description="Language for generated documents",
    )
    doc_type: str = Field(
        "business and administrative",
        description="Type of documents to generate (e.g., 'business and administrative', 'receipts', 'forms')",
    )
    gt_type: str = Field(
        "Multiple questions about each document, with their answers taken **verbatim** from the document.",
        description="Description of ground truth type to generate",
    )
    gt_format: str = Field(
        '{"<Text of question 1>": "<Answer to question 1>", "<Text of question 2>": "<Answer to question 2>", ...}',
        description="Format specification for ground truth JSON",
    )
    num_solutions: int = Field(
        1,
        description="Number of document variations to generate (1-5)",
        ge=1,
        le=5,
    )

    # ---- Stage 3: feature synthesis ----
    enable_handwriting: bool = Field(
        False,
        description="Enable handwriting generation (requires EC2 handwriting service)",
    )
    handwriting_ratio: float = Field(
        0.2,
        description="Proportion of text to convert to handwriting (0.0-1.0)",
        ge=0.0,
        le=1.0,
    )
    handwriting_apply_ink_filter: bool = Field(
        True,
        description="Apply high-contrast ink filter to handwriting (v16+ feature)",
    )
    handwriting_enable_enhancements: bool = Field(
        False,
        description="Enable sharpening and contrast boosting (Experimental)",
    )
    handwriting_num_inference_steps: int = Field(
        1000,
        description="Number of diffusion inference steps (1-1000)",
        ge=1,
        le=1000,
    )
    # No range constraint is declared on the individual writer IDs.
    handwriting_writer_ids: List[int] = Field(
        [404, 347, 156, 253, 354, 166, 320],
        description="List of writer style IDs to use for handwriting generation",
    )
    enable_visual_elements: bool = Field(
        True,
        description="Enable visual element generation (stamps, logos, barcodes)",
    )
    # Free-form strings: the element type names are not validated here.
    visual_element_types: List[str] = Field(
        ["stamp", "logo", "figure", "barcode", "photo"],
        description="Types of visual elements to generate (stamp, logo, figure, barcode, photo)",
    )
    # The "numeric only" requirement is documented but not enforced by
    # this schema.
    barcode_number: Optional[str] = Field(
        None,
        description="Optional fixed number for barcode generation (numeric only)",
    )
    seed: Optional[int] = Field(
        None,
        description="Random seed for reproducible generation",
        examples=[None, 42],
    )

    # ---- Stage 4: image finalization & OCR ----
    enable_ocr: bool = Field(
        True,
        description="Enable OCR on final document images (requires OCR service)",
    )
    ocr_language: str = Field(
        "en",
        description="Language for OCR (e.g., 'en', 'de', 'fr')",
    )

    # ---- Stage 5: dataset packaging ----
    enable_bbox_normalization: bool = Field(
        True,
        description="Normalize bounding boxes to [0,1] scale (Stage 16)",
    )
    enable_gt_verification: bool = Field(
        True,
        description="Verify and prepare ground truth annotations (Stage 17)",
    )
    enable_analysis: bool = Field(
        True,
        description="Generate dataset statistics and analysis (Stage 18)",
    )
    enable_debug_visualization: bool = Field(
        True,
        description="Create debug visualization overlays (Stage 19)",
    )
    enable_dataset_export: bool = Field(
        True,
        description="Export as msgpack dataset format",
    )
    # Plain `str` fields: the values listed in the two descriptions below
    # are not enforced by this schema.
    dataset_export_format: str = Field(
        "msgpack",
        description="Dataset export format: 'msgpack', 'coco', 'huggingface'",
    )
    output_detail: str = Field(
        "dataset",
        description="Output detail level: 'minimal' (final outputs only), 'dataset' (includes individual tokens/elements for ML), 'complete' (all intermediate files and debug info). Warning: 'complete' mode can produce 50+ MB responses.",
    )
class SeedImage(BaseModel):
    """Seed image URL for document generation."""

    # When the caller omits `url`, a fixed ocr.space image URL is used.
    url: HttpUrl = Field(
        HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp"),
        description="URL of the seed image",
    )
class GenerateDocumentRequest(BaseModel):
    """Request schema for document generation endpoint."""

    # The only required field; everything else has a default.
    request_id: str = Field(
        description="Document request UUID from document_requests table (created by frontend)"
    )
    google_drive_token: Optional[str] = Field(
        default=None,
        description="Google Drive OAuth access token. Frontend provides this after OAuth flow (optional)."
    )
    google_drive_refresh_token: Optional[str] = Field(
        default=None,
        description="Google Drive refresh token (optional, for automatic token renewal)"
    )
    # Defaults to a single fixed ocr.space image URL when omitted.
    seed_images: List[HttpUrl] = Field(
        default=[HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp")],
        description="List of seed image URLs (1-10 images)"
    )
    prompt_params: PromptParameters = Field(
        default_factory=PromptParameters,
        description="Parameters for customizing the generation prompt"
    )

    @field_validator('seed_images')
    @classmethod
    def validate_seed_images(cls, v: List[HttpUrl]) -> List[HttpUrl]:
        """Check that between 1 and 10 seed images were supplied.

        Args:
            v: The parsed ``seed_images`` list.

        Returns:
            The list unchanged when its length is within bounds.

        Raises:
            ValueError: If the list is empty or holds more than 10 items.
        """
        # An empty list is the only way to have fewer than one element, so
        # one truthiness check covers the lower bound (the former
        # `len(v) < 1` branch could never be reached).
        if not v:
            raise ValueError('At least one seed image is required')
        if len(v) > 10:
            raise ValueError('Maximum 10 seed images allowed')
        return v
class OCRWord(BaseModel):
    """OCR word-level result."""

    # All fields are required; `...` marks a field with no default.
    text: str = Field(..., description="Recognized text")
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Box geometry in pixels; no bounds are declared on these values.
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")
class OCRLine(BaseModel):
    """OCR line-level result."""

    # Same required text/confidence/geometry fields as OCRWord, plus the
    # list of words belonging to the line.
    text: str = Field(..., description="Recognized text")
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")

    # Empty list when no words are supplied.
    words: List[OCRWord] = Field(
        description="Words in this line",
        default_factory=list,
    )
class OCRResult(BaseModel):
    """OCR results for a document."""

    # Image dimensions are required; the result lists default to empty.
    image_width: int = Field(..., description="Image width in pixels")
    image_height: int = Field(..., description="Image height in pixels")
    words: List[OCRWord] = Field(
        description="Word-level OCR results",
        default_factory=list,
    )
    lines: List[OCRLine] = Field(
        description="Line-level OCR results",
        default_factory=list,
    )
    # The unit of the angle is not stated in this schema.
    angle: float = Field(0.0, description="Detected text orientation angle")
class CostInfo(BaseModel):
    """Cost information for a request (Research Parity)."""

    # Required token counts.
    input_tokens: int = Field(..., description="Number of input tokens")
    output_tokens: int = Field(..., description="Number of output tokens")

    # Cache counters default to zero when not reported.
    cache_creation_tokens: int = Field(
        0,
        description="Tokens used for cache creation",
    )
    cache_read_tokens: int = Field(
        0,
        description="Tokens read from cache",
    )

    cost_usd: float = Field(
        ...,
        description="Total cost in USD (with 50% batch discount applied if applicable)",
    )
    batch_discount_applied: bool = Field(
        False,
        description="Whether 50% batch discount was applied",
    )
class NormalizedBBox(BaseModel):
    """Normalized bounding box (Stage 16)."""

    text: str = Field(..., description="Text content")

    # Corner coordinates, each constrained to the closed interval [0, 1].
    # The max corner is named x2/y2 (not x1/y1) in this schema.
    x0: float = Field(
        ...,
        description="Normalized X min (0-1)",
        ge=0.0,
        le=1.0,
    )
    y0: float = Field(
        ...,
        description="Normalized Y min (0-1)",
        ge=0.0,
        le=1.0,
    )
    x2: float = Field(
        ...,
        description="Normalized X max (0-1)",
        ge=0.0,
        le=1.0,
    )
    y2: float = Field(
        ...,
        description="Normalized Y max (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Optional positional indices; None when not provided.
    block_no: Optional[int] = Field(None, description="Block number")
    line_no: Optional[int] = Field(None, description="Line number")
    word_no: Optional[int] = Field(None, description="Word number")
class GTVerificationResult(BaseModel):
    """Ground truth verification results (Stage 17)."""

    # `passed` is the only required field.
    passed: bool = Field(..., description="Whether GT verification passed")
    skipped: bool = Field(
        False,
        description="Whether verification was skipped",
    )
    confirmed_keys: List[str] = Field(
        description="Confirmed GT keys",
        default_factory=list,
    )
    similarities: List[float] = Field(
        description="Similarity scores",
        default_factory=list,
    )
    num_layout_elements: Optional[int] = Field(
        None,
        description="Number of layout elements",
    )
    valid_labels: bool = Field(
        True,
        description="Whether all labels are valid",
    )
class AnalysisStats(BaseModel):
    """Dataset analysis and statistics (Stage 18)."""

    # Required totals.
    total_documents: int = Field(..., description="Total documents processed")
    valid_documents: int = Field(
        ...,
        description="Documents passing all validation",
    )

    # Untyped dict: key/value types are not constrained by this schema.
    error_counts: dict = Field(
        description="Error type counts",
        default_factory=dict,
    )

    # Integer counters (not flags), each defaulting to zero.
    has_handwriting: int = Field(0, description="Documents with handwriting")
    has_visual_elements: int = Field(
        0,
        description="Documents with visual elements",
    )
    has_ocr: int = Field(0, description="Documents with OCR results")
    multipage_count: int = Field(0, description="Multipage documents")

    token_usage: Optional[dict] = Field(
        None,
        description="LLM token usage statistics",
    )
class DebugVisualization(BaseModel):
    """Debug visualization data (Stage 19)."""

    # Each overlay is optional and defaults to None.
    bbox_overlay_base64: Optional[str] = Field(
        None,
        description="Image with bbox overlays (PNG base64)",
    )
    visual_elements_overlay_base64: Optional[str] = Field(
        None,
        description="Image with visual element overlays",
    )
    handwriting_overlay_base64: Optional[str] = Field(
        None,
        description="Image with handwriting overlays",
    )
class DatasetExportInfo(BaseModel):
    """Dataset export metadata."""

    # Free-form string: the format name is not validated by this schema.
    format: str = Field(..., description="Export format (msgpack, coco, etc.)")
    num_samples: int = Field(..., description="Number of samples in export")

    # Two optional fields for the export itself (a path and an inline
    # base64 payload); both default to None.
    output_path: Optional[str] = Field(
        None,
        description="Path to exported dataset",
    )
    msgpack_base64: Optional[str] = Field(
        None,
        description="Msgpack file as base64 (for small datasets)",
    )

    metadata: dict = Field(
        description="Dataset metadata",
        default_factory=dict,
    )
class BoundingBox(BaseModel):
    """Bounding box for a text element in the document."""

    text: str = Field(..., description="Text content")

    # Described as normalized 0-1, but no ge/le constraint is declared on
    # these fields, so out-of-range values are accepted.
    x: float = Field(..., description="X coordinate (normalized 0-1)")
    y: float = Field(..., description="Y coordinate (normalized 0-1)")
    width: float = Field(..., description="Width (normalized 0-1)")
    height: float = Field(..., description="Height (normalized 0-1)")

    page: int = Field(0, description="Page number (0-indexed)")
class HandwritingRegion(BaseModel):
    """Information about a handwriting region in the document."""

    # All four fields are required.
    region_id: str = Field(..., description="Unique region identifier")
    text: str = Field(..., description="Text content")
    author_id: int = Field(
        ...,
        description="Author ID for style consistency (0-656)",
        ge=0,
        le=656,
    )
    bbox: BoundingBox = Field(..., description="Bounding box of the region")
class VisualElement(BaseModel):
    """Information about a visual element in the document."""

    element_id: str = Field(..., description="Unique element identifier")
    # Free-form string: the type name is not validated by this schema.
    element_type: str = Field(
        ...,
        description="Type of visual element (stamp, logo, etc.)",
    )
    content: Optional[str] = Field(
        None,
        description="Content (e.g., stamp text)",
    )
    bbox: BoundingBox = Field(..., description="Bounding box of the element")
class DocumentResult(BaseModel):
    """Result for a single generated document."""

    # ---- Core outputs ----
    document_id: str = Field(..., description="Unique document identifier")
    html: str = Field(..., description="Generated HTML content")
    css: str = Field(..., description="Extracted CSS styles")
    ground_truth: Optional[dict] = Field(
        None,
        description="Ground truth data extracted from the document",
    )
    pdf_base64: str = Field(..., description="Base64-encoded PDF document")
    bboxes: List[BoundingBox] = Field(
        description="Bounding boxes for text elements",
        default_factory=list,
    )
    page_width_mm: float = Field(..., description="Page width in millimeters")
    page_height_mm: float = Field(..., description="Page height in millimeters")

    # ---- Stage 3 additions ----
    # Both lists hold plain dicts; their contents are not validated here.
    handwriting_regions: Optional[List[dict]] = Field(
        None,
        description="Handwriting regions with metadata (if enabled)",
    )
    visual_elements: Optional[List[dict]] = Field(
        None,
        description="Visual elements with metadata (if enabled)",
    )
    image_base64: Optional[str] = Field(
        None,
        description="Final rendered image with handwriting/visuals (PNG base64, if Stage 3 enabled)",
    )

    # ---- Stage 3 individual tokens (dataset/complete detail levels) ----
    handwriting_token_images: Optional[dict] = Field(
        None,
        description="Individual handwriting token images {hw_id: base64_png} (output_detail: dataset/complete)",
    )
    visual_element_images: Optional[dict] = Field(
        None,
        description="Individual visual element images {ve_id: base64_png} (output_detail: dataset/complete)",
    )
    token_mapping: Optional[dict] = Field(
        None,
        description="Token mapping with positions and style IDs (output_detail: dataset/complete)",
    )

    # ---- Stage 4 additions ----
    ocr_results: Optional[OCRResult] = Field(
        None,
        description="OCR results from final image (if OCR enabled)",
    )

    # ---- Stage 5 additions ----
    normalized_bboxes_word: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Word-level normalized bounding boxes (if Stage 16 enabled)",
    )
    normalized_bboxes_segment: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Segment-level normalized bounding boxes (if Stage 16 enabled)",
    )
    gt_verification: Optional[GTVerificationResult] = Field(
        None,
        description="Ground truth verification results (if Stage 17 enabled)",
    )
    analysis_stats: Optional[AnalysisStats] = Field(
        None,
        description="Document analysis statistics (if Stage 18 enabled)",
    )
    debug_visualization: Optional[DebugVisualization] = Field(
        None,
        description="Debug visualization overlays (if Stage 19 enabled)",
    )
    dataset_export: Optional[DatasetExportInfo] = Field(
        None,
        description="Dataset export information (if export enabled)",
    )
    cost_info: Optional[CostInfo] = Field(
        None,
        description="Cost information for this document (Research Parity)",
    )
class GenerateDocumentResponse(BaseModel):
    """Response schema for document generation endpoint."""

    # `success` and `message` are required; the rest have defaults.
    success: bool = Field(..., description="Whether generation was successful")
    message: str = Field(..., description="Status message")
    documents: List[DocumentResult] = Field(
        description="List of generated documents",
        default_factory=list,
    )
    # Independent field: it is not derived from len(documents) here.
    total_documents: int = Field(
        0,
        description="Total number of documents generated",
    )
    total_cost: Optional[CostInfo] = Field(
        None,
        description="Aggregated cost for the entire request",
    )
class HealthResponse(BaseModel):
    """Health check response."""

    # Both values are static defaults declared on the schema.
    status: str = Field("healthy")
    version: str = Field("1.0.0")