Spaces:
Running
Running
import os | |
import base64 | |
import gradio as gr | |
import requests | |
import shutil | |
import time | |
import pymupdf as fitz | |
import logging | |
from mistralai import Mistral, ImageURLChunk | |
from mistralai.models import OCRResponse | |
from typing import Union, List, Tuple, Optional, Dict | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
from concurrent.futures import ThreadPoolExecutor | |
import tempfile | |
# Constants | |
SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"] | |
SUPPORTED_PDF_TYPES = [".pdf"] | |
UPLOAD_FOLDER = "./uploads" | |
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB | |
MAX_PDF_PAGES = 50 | |
# Configuration | |
os.makedirs(UPLOAD_FOLDER, exist_ok=True) | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
handlers=[logging.StreamHandler()] | |
) | |
logger = logging.getLogger(__name__) | |
class OCRProcessor: | |
def __init__(self, api_key: str): | |
if not api_key or not isinstance(api_key, str): | |
raise ValueError("Valid API key must be provided") | |
self.client = Mistral(api_key=api_key) | |
self._validate_client() | |
def _validate_client(self) -> None: | |
try: | |
models = self.client.models.list() | |
if not models: | |
raise ValueError("No models available") | |
except Exception as e: | |
raise ValueError(f"API key validation failed: {str(e)}") | |
def _check_file_size(file_input: Union[str, bytes]) -> None: | |
if isinstance(file_input, str) and os.path.exists(file_input): | |
size = os.path.getsize(file_input) | |
elif hasattr(file_input, 'read'): | |
size = len(file_input.read()) | |
file_input.seek(0) | |
else: | |
size = len(file_input) | |
if size > MAX_FILE_SIZE: | |
raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit") | |
def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str: | |
clean_filename = os.path.basename(filename).replace(os.sep, "_") | |
file_path = os.path.join(UPLOAD_FOLDER, f"{int(time.time())}_{clean_filename}") | |
try: | |
if isinstance(file_input, str) and file_input.startswith("http"): | |
response = requests.get(file_input, timeout=30) | |
response.raise_for_status() | |
with open(file_path, 'wb') as f: | |
f.write(response.content) | |
elif isinstance(file_input, str) and os.path.exists(file_input): | |
shutil.copy2(file_input, file_path) | |
else: | |
with open(file_path, 'wb') as f: | |
if hasattr(file_input, 'read'): | |
shutil.copyfileobj(file_input, f) | |
else: | |
f.write(file_input) | |
if not os.path.exists(file_path): | |
raise FileNotFoundError(f"Failed to save file at {file_path}") | |
return file_path | |
except Exception as e: | |
logger.error(f"Error saving file {filename}: {str(e)}") | |
raise | |
def _encode_image(image_path: str) -> str: | |
try: | |
with open(image_path, "rb") as image_file: | |
return base64.b64encode(image_file.read()).decode('utf-8') | |
except Exception as e: | |
logger.error(f"Error encoding image {image_path}: {str(e)}") | |
raise ValueError(f"Failed to encode image: {str(e)}") | |
def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]: | |
try: | |
pdf_document = fitz.open(pdf_path) | |
if pdf_document.page_count > MAX_PDF_PAGES: | |
pdf_document.close() | |
raise ValueError(f"PDF exceeds maximum page limit of {MAX_PDF_PAGES}") | |
with ThreadPoolExecutor() as executor: | |
image_data = list(executor.map( | |
lambda i: OCRProcessor._convert_page(pdf_path, i), | |
range(pdf_document.page_count) | |
)) | |
pdf_document.close() | |
return [data for data in image_data if data] | |
except Exception as e: | |
logger.error(f"Error converting PDF to images: {str(e)}") | |
return [] | |
def _convert_page(pdf_path: str, page_num: int) -> Tuple[str, str]: | |
try: | |
pdf_document = fitz.open(pdf_path) | |
page = pdf_document[page_num] | |
pix = page.get_pixmap(dpi=150) | |
image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}_{int(time.time())}.png") | |
pix.save(image_path) | |
encoded = OCRProcessor._encode_image(image_path) | |
pdf_document.close() | |
return image_path, encoded | |
except Exception as e: | |
logger.error(f"Error converting page {page_num}: {str(e)}") | |
return None, None | |
def _call_ocr_api(self, encoded_image: str) -> OCRResponse: | |
base64_url = f"data:image/png;base64,{encoded_image}" | |
try: | |
logger.info("Calling OCR API") | |
response = self.client.ocr.process( | |
model="mistral-ocr-latest", | |
document=ImageURLChunk(image_url=base64_url), | |
include_image_base64=True | |
) | |
return response | |
except Exception as e: | |
logger.error(f"OCR API call failed: {str(e)}") | |
raise | |
def process_file(self, file: gr.File) -> Tuple[str, str, List[str]]: | |
"""Process uploaded file (image or PDF).""" | |
if not file: | |
return "## No file provided", "", [] | |
file_name = file.name | |
self._check_file_size(file) | |
file_path = self._save_uploaded_file(file, file_name) | |
if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)): | |
encoded_image = self._encode_image(file_path) | |
response = self._call_ocr_api(encoded_image) | |
markdown = self._combine_markdown(response) | |
return markdown, file_path, [file_path] | |
elif file_name.lower().endswith('.pdf'): | |
image_data = self._pdf_to_images(file_path) | |
if not image_data: | |
return "## No pages converted from PDF", file_path, [] | |
ocr_results = [] | |
image_paths = [path for path, _ in image_data] | |
for _, encoded in image_data: | |
response = self._call_ocr_api(encoded) | |
markdown = self._combine_markdown(response) | |
ocr_results.append(markdown) | |
return "\n\n".join(ocr_results), file_path, image_paths | |
return "## Unsupported file type", file_path, [] | |
def process_url(self, url: str) -> Tuple[str, str, List[str]]: | |
"""Process URL (image or PDF).""" | |
if not url: | |
return "## No URL provided", "", [] | |
file_name = url.split('/')[-1] or f"file_{int(time.time())}" | |
file_path = self._save_uploaded_file(url, file_name) | |
if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)): | |
encoded_image = self._encode_image(file_path) | |
response = self._call_ocr_api(encoded_image) | |
markdown = self._combine_markdown(response) | |
return markdown, url, [file_path] | |
elif file_name.lower().endswith('.pdf'): | |
image_data = self._pdf_to_images(file_path) | |
if not image_data: | |
return "## No pages converted from PDF", url, [] | |
ocr_results = [] | |
image_paths = [path for path, _ in image_data] | |
for _, encoded in image_data: | |
response = self._call_ocr_api(encoded) | |
markdown = self._combine_markdown(response) | |
ocr_results.append(markdown) | |
return "\n\n".join(ocr_results), url, image_paths | |
return "## Unsupported URL content type", url, [] | |
def _combine_markdown(response: OCRResponse) -> str: | |
"""Combine markdown from OCR response.""" | |
markdown_parts = [] | |
for page in response.pages: | |
if not page.markdown.strip(): | |
continue | |
markdown = page.markdown | |
if hasattr(page, 'images') and page.images: | |
for img in page.images: | |
if img.image_base64: | |
markdown = markdown.replace( | |
f"", | |
f"" | |
) | |
markdown_parts.append(markdown) | |
return "\n\n".join(markdown_parts) or "## No text detected" | |
def create_interface(): | |
css = """ | |
.output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;} | |
.status {color: #666; font-style: italic;} | |
.preview {max-height: 300px;} | |
""" | |
with gr.Blocks(title="Mistral OCR Demo", css=css) as demo: | |
gr.Markdown("# Mistral OCR Demo") | |
gr.Markdown(f""" | |
Process PDFs and images (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages for PDFs) via upload or URL. | |
View previews and OCR results with embedded images. | |
Learn more at [Mistral OCR](https://mistral.ai/news/mistral-ocr). | |
""") | |
# API Key Setup | |
with gr.Row(): | |
api_key_input = gr.Textbox(label="Mistral API Key", type="password", placeholder="Enter your API key") | |
set_key_btn = gr.Button("Set API Key", variant="primary") | |
processor_state = gr.State() | |
status = gr.Markdown("Please enter API key", elem_classes="status") | |
def init_processor(key): | |
try: | |
processor = OCRProcessor(key) | |
return processor, "✅ API key validated" | |
except Exception as e: | |
return None, f"❌ Error: {str(e)}" | |
set_key_btn.click(fn=init_processor, inputs=api_key_input, outputs=[processor_state, status]) | |
# File Upload Tab | |
with gr.Tab("Upload File"): | |
with gr.Row(): | |
file_input = gr.File(label="Upload PDF/Image", file_types=SUPPORTED_IMAGE_TYPES + SUPPORTED_PDF_TYPES) | |
file_preview = gr.Gallery(label="Preview", elem_classes="preview") | |
file_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown") | |
file_raw_output = gr.Textbox(label="Raw File Path") | |
file_button = gr.Button("Process", variant="primary") | |
def update_file_preview(file): | |
return [file.name] if file else [] | |
file_input.change(fn=update_file_preview, inputs=file_input, outputs=file_preview) | |
file_button.click( | |
fn=lambda p, f: p.process_file(f) if p else ("## Set API key first", "", []), | |
inputs=[processor_state, file_input], | |
outputs=[file_output, file_raw_output, file_preview] | |
) | |
# URL Tab | |
with gr.Tab("URL Input"): | |
with gr.Row(): | |
url_input = gr.Textbox(label="URL to PDF/Image") | |
url_preview = gr.Gallery(label="Preview", elem_classes="preview") | |
url_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown") | |
url_raw_output = gr.Textbox(label="Raw URL") | |
url_button = gr.Button("Process", variant="primary") | |
def update_url_preview(url): | |
if not url: | |
return [] | |
try: | |
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') | |
response = requests.get(url, timeout=10) | |
temp_file.write(response.content) | |
temp_file.close() | |
return [temp_file.name] | |
except Exception as e: | |
logger.error(f"URL preview error: {str(e)}") | |
return [] | |
url_input.change(fn=update_url_preview, inputs=url_input, outputs=url_preview) | |
url_button.click( | |
fn=lambda p, u: p.process_url(u) if p else ("## Set API key first", "", []), | |
inputs=[processor_state, url_input], | |
outputs=[url_output, url_raw_output, url_preview] | |
) | |
# Examples | |
gr.Examples( | |
examples=[], | |
inputs=[file_input, url_input] | |
) | |
return demo | |
if __name__ == "__main__": | |
os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S') | |
print(f"===== Application Startup at {os.environ['START_TIME']} =====") | |
create_interface().launch(share=True, max_threads=1) |