import gradio as gr
from PIL import Image
import os
import json
import time
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import spaces

ckpt = "unsloth/Llama-3.2-11B-Vision-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
model = MllamaForConditionalGeneration.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
).to(device)
processor = AutoProcessor.from_pretrained(ckpt)

SAVE_DIR = "corrections"
os.makedirs(SAVE_DIR, exist_ok=True)


@spaces.GPU
def ocr_on_image(image):
    # Alternative prompt drafts kept for reference; the chat message below
    # currently uses its own, shorter instruction.
    prompt = (
        "Output ONLY the raw text exactly as it appears in the image. Do not add anything.\n\n"
        "The image may contain both handwritten and printed text in French and/or English, including punctuation and underscores.\n\n"
        "Your task: Transcribe all visible text exactly, preserving:\n"
        "- All characters, accents, punctuation, spacing, and line breaks.\n"
        "- The original reading order and layout, including tables and forms if present.\n\n"
        "Rules:\n"
        "- Do NOT add any explanations, summaries, comments, or extra text.\n"
        "- Do NOT duplicate any content.\n"
        "- Do NOT indicate blank space.\n"
        "- Do NOT separate handwritten and printed text.\n"
        "- Do NOT confuse '.' (a period) with '|' (a border).\n\n"
        "Only extract the text that is actually visible in the image, and nothing else."
    )
    prompt2 = (
        "Extract all visible text from the image, including both handwritten and printed content. "
        "Do not translate the text — preserve the original language exactly as it appears. "
        "Return only the extracted text, with no explanation, no formatting, and no additions."
    )
    prompt3 = (
        "Output ONLY the raw text as it appears in the image, nothing else.\n"
        "You have an image containing both handwritten and printed text in French and/or English, and also punctuation and underscores.\n"
        "Your task: transcribe EXACTLY all visible text, preserving all characters, accents, punctuation, spacing, and line breaks.\n"
        "Include tables and forms clearly if present.\n"
        "Do NOT add any explanations, comments, summaries, or extra text.\n"
        "Check the output first to avoid duplicating results.\n"
        "Preserve the original reading order, including line breaks and the natural layout of tables or forms. Output the text exactly as it appears visually, maintaining the structure.\n"
        "Do NOT indicate blank space.\n"
        "Do NOT separate handwritten and printed text.\n"
        "Do NOT confuse '.' (a period) with '|' (a border).\n"
        "Extract only the raw text and do not add any comments.\n"
        "Extract the content line by line."
    )

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Extract handwritten text from the image and output only the extracted text without any additional description or commentary in output"},
            {"type": "image"},
        ],
    }]
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=texts, images=[image], return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=250)
    result = processor.decode(outputs[0], skip_special_tokens=True)

    # Crude string-based cleanup: drop everything up to the assistant turn,
    # then strip leftover role/prompt fragments from the decoded sequence.
    if "assistant" in result.lower():
        result = result[result.lower().find("assistant") + len("assistant"):].strip()
    result = result.replace("user", "").replace(prompt, "").strip()
    return result
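
# Hedged alternative to the string-based cleanup above: since Llama 3.2 Vision
# is decoder-only, model.generate returns the prompt tokens followed by the new
# tokens, so slicing at the prompt length decodes only the model's answer.
# A sketch, not wired into the app; assumes `inputs` carries "input_ids".
def _decode_generated_only(inputs, outputs):
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()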

def batch_ocr(images):
    if not images:
        return [], "No images uploaded."

    results = []
    status_text = f"Processing {len(images)} image(s)...\n"
    for i, img_file in enumerate(images):
        try:
            pil_img = Image.open(img_file.name).convert("RGB")
            text = ocr_on_image(pil_img)
            results.append({
                "image": pil_img,
                "filepath": img_file.name,
                "ocr_text": text,
                "corrected_text": text,
            })
            status_text += f"Image {i+1}: ✓ Text extracted\n"
        except Exception as e:
            status_text += f"Image {i+1}: ❌ Error: {str(e)}\n"
    return results, status_text


def save_all_corrections(data_list, *corrections):
    if not data_list:
        return "No data to save."

    # Update the entries with the manually corrected texts
    for i, correction in enumerate(corrections):
        if i < len(data_list) and correction.strip():
            data_list[i]["corrected_text"] = correction

    timestamp = int(time.time())
    saved_files = []
    for i, data in enumerate(data_list):
        img_path = f"{SAVE_DIR}/image_{timestamp}_{i}.png"
        json_path = f"{SAVE_DIR}/correction_{timestamp}_{i}.jsonl"
        data["image"].save(img_path)
        entry = {
            "image_path": img_path,
            "ocr_text": data["ocr_text"],
            "corrected_text": data["corrected_text"],
        }
        with open(json_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        saved_files.append(json_path)

    return f"✅ {len(saved_files)} correction(s) saved in the '{SAVE_DIR}' folder."
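
# Hedged variant of the saving scheme above: append every correction to one
# shared JSONL dataset file instead of one file per image. `DATASET_PATH` and
# `append_to_dataset` are illustrative names, not used by the interface below.
DATASET_PATH = os.path.join(SAVE_DIR, "dataset.jsonl")

def append_to_dataset(entry):
    # One JSON object per line (standard JSONL), so the file can grow across sessions.
    with open(DATASET_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")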

# Gradio interface
with gr.Blocks(title="OCR with Llama Vision", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Multi-Image OCR with Manual Correction")
    gr.Markdown("Upload your images, extract the text automatically, then correct it if needed.")

    with gr.Row():
        uploaded = gr.Files(
            file_types=[".png", ".jpg", ".jpeg", ".tif"],
            label="📁 Upload multiple images",
            file_count="multiple",
        )

    btn_ocr = gr.Button("🚀 Extract OCR text", variant="primary", size="lg")
    status = gr.Textbox(label="📊 Status", lines=3, visible=False)

    # Containers for the results (fixed, not dynamic)
    results_data = gr.State([])

    with gr.Column(visible=False) as results_section:
        gr.Markdown("## 📝 OCR Results - You can edit the text below")

        # Fixed interface for up to 5 images (adjust to your needs)
        image_components = []
        text_components = []

        for i in range(5):  # Maximum of 5 images
            with gr.Row(visible=False) as row:
                with gr.Column(scale=1):
                    img_comp = gr.Image(label=f"Image {i+1}", height=300)
                    image_components.append((row, img_comp))
                with gr.Column(scale=2):
                    txt_comp = gr.Textbox(
                        label=f"Extracted text - Image {i+1}",
                        lines=10,
                        placeholder="The extracted text will appear here...",
                    )
                    text_components.append(txt_comp)

        btn_save = gr.Button("💾 Save all corrections", variant="secondary", size="lg")
        save_status = gr.Textbox(label="💾 Save status", visible=False)

    def process_images(images):
        if not images:
            # Each image slot needs two updates (row visibility + image value)
            # so the tuple length matches the outputs list wired below.
            hidden = []
            for _ in range(5):
                hidden.extend([gr.update(visible=False), gr.update(value=None)])
            return (
                gr.update(visible=True, value="❌ No images uploaded."),
                gr.update(visible=False),
                gr.update(visible=False),
                [],
                *hidden,
                *[gr.update(value="") for _ in range(5)],
            )

        results, status_text = batch_ocr(images)

        # Update the image and text components
        image_updates = []
        text_updates = []

        for i in range(5):
            if i < len(results):
                # Show the row and fill in the image and text
                image_updates.append(gr.update(visible=True))
                image_updates.append(gr.update(value=results[i]["image"]))
                text_updates.append(gr.update(value=results[i]["ocr_text"]))
            else:
                # Hide the unused components
                image_updates.append(gr.update(visible=False))
                image_updates.append(gr.update(value=None))
                text_updates.append(gr.update(value=""))

        return (
            gr.update(visible=True, value=status_text),
            gr.update(visible=True),
            gr.update(visible=True),
            results,
            *image_updates,
            *text_updates,
        )

    # Flatten the (row, image) pairs into the output list for the click handler
    image_outputs = []
    for row, img in image_components:
        image_outputs.extend([row, img])

    btn_ocr.click(
        process_images,
        inputs=[uploaded],
        outputs=[
            status,
            results_section,
            save_status,
            results_data,
            *image_outputs,
            *text_components,
        ],
    )

    btn_save.click(
        save_all_corrections,
        inputs=[results_data] + text_components,
        outputs=save_status,
    )

if __name__ == "__main__":
    demo.launch(debug=True)
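
# Appendix-style sketch (hedged): how the per-image JSONL files written by
# save_all_corrections could be read back, e.g. to inspect corrections or
# assemble a fine-tuning set. `load_corrections` is illustrative and is not
# called anywhere in the app above.
def load_corrections(save_dir=SAVE_DIR):
    entries = []
    for name in sorted(os.listdir(save_dir)):
        if name.endswith(".jsonl"):
            with open(os.path.join(save_dir, name), encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        entries.append(json.loads(line))
    return entries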