import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import subprocess
from datetime import datetime
import numpy as np
import os
from gliner import GLiNER
import json
import tempfile
import zipfile
import base64
import io

# Initialize GLiNER model
gliner_model = GLiNER.from_pretrained("knowledgator/modern-gliner-bi-large-v1.0")

DEFAULT_NER_LABELS = "person, organization, location, date, event"

# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# models = {
#     "Qwen/Qwen2-VL-7B-Instruct": AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
# }


class TextWithMetadata(list):
    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.original_text = kwargs.get('original_text', '')
        self.entities = kwargs.get('entities', [])


def array_to_image_path(image_array):
    # Convert numpy array to PIL Image
    img = Image.fromarray(np.uint8(image_array))
    img.thumbnail((1024, 1024))

    # Generate a unique filename using timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}.png"

    # Save the image
    img.save(filename)

    # Get the full path of the saved image
    full_path = os.path.abspath(filename)

    return full_path


models = {
    "Qwen/Qwen2.5-VL-7B-Instruct": Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto"
    ).cuda().eval()
}

processors = {
    "Qwen/Qwen2.5-VL-7B-Instruct": AutoProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True
    )
}

DESCRIPTION = "This demo uses [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)"

kwargs = {}
kwargs['torch_dtype'] = torch.bfloat16

user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"


@spaces.GPU
def run_example(image, model_id="Qwen/Qwen2.5-VL-7B-Instruct", run_ner=False, ner_labels=DEFAULT_NER_LABELS):
    # First get the OCR text
    text_input = "Convert the image to text."
    # Print debug info about the image type
    print(f"Image type: {type(image)}")
    print(f"Image value: {image}")

    # Robust handling of image input
    try:
        # Handle None or empty input
        if image is None:
            raise ValueError("Image input is None")

        # Handle dictionary input (from API)
        if isinstance(image, dict):
            if 'data' in image and isinstance(image['data'], str) and image['data'].startswith('data:image'):
                # Extract the base64 part
                base64_data = image['data'].split(',', 1)[1]
                # Convert base64 to bytes, then to PIL Image
                image_bytes = base64.b64decode(base64_data)
                pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                # Convert to numpy array
                image = np.array(pil_image)
            else:
                raise ValueError(f"Invalid image dictionary format: {image}")

        # Convert string path to image if needed
        if isinstance(image, str):
            pil_image = Image.open(image).convert("RGB")
            image = np.array(pil_image)

        # Ensure image is a numpy array
        if not isinstance(image, np.ndarray):
            raise ValueError(f"Unsupported image type: {type(image)}")

        # Convert numpy array to image path
        image_path = array_to_image_path(image)

        model = models[model_id]
        processor = processors[model_id]

        prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
        pil_image = Image.fromarray(image).convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image_path,
                    },
                    {"type": "text", "text": text_input},
                ],
            }
        ]

        # Preparation for inference
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda")

        # Inference: Generation of the output
        generated_ids = model.generate(**inputs, max_new_tokens=1024)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        ocr_text = output_text[0]

        # If NER is enabled, process the OCR text
        if run_ner:
            # Strip whitespace from the comma-separated labels before handing them to GLiNER
            labels = [label.strip() for label in ner_labels.split(",") if label.strip()]
            ner_results = gliner_model.predict_entities(
                ocr_text,
                labels,
                threshold=0.3
            )

            # Create a list of tuples (text, label) for highlighting
            highlighted_text = []
            last_end = 0

            # Sort entities by start position
            sorted_entities = sorted(ner_results, key=lambda x: x["start"])

            # Process each entity and add non-entity text segments
            for entity in sorted_entities:
                # Add non-entity text before the current entity
                if last_end < entity["start"]:
                    highlighted_text.append((ocr_text[last_end:entity["start"]], None))
                # Add the entity text with its label
                highlighted_text.append((
                    ocr_text[entity["start"]:entity["end"]],
                    entity["label"]
                ))
                last_end = entity["end"]

            # Add any remaining text after the last entity
            if last_end < len(ocr_text):
                highlighted_text.append((ocr_text[last_end:], None))

            # Create TextWithMetadata instance with the highlighted text and metadata
            result = TextWithMetadata(highlighted_text, original_text=ocr_text, entities=ner_results)
            return result, result  # Return twice: once for display, once for state

        # If NER is disabled, return the text without highlighting
        result = TextWithMetadata([(ocr_text, None)], original_text=ocr_text, entities=[])
        return result, result  # Return twice: once for display, once for state

    except Exception as e:
        import traceback
        print(f"Error processing image: {e}")
        print(traceback.format_exc())
        # Return an error result so the UI and state still receive valid outputs
        result = TextWithMetadata(
            [("Error processing image: " + str(e), None)],
            original_text="Error: " + str(e),
            entities=[],
        )
        return result, result


with gr.Blocks() as demo:
    # State variable to carry OCR results between callbacks
    ocr_state = gr.State()

    # gr.Image("Caracal.jpg", interactive=False)
    with gr.Tab(label="Image Input", elem_classes="tabs"):
        with gr.Row():
            with gr.Column(elem_classes="input-container"):
                input_img = gr.Image(label="Input Picture", elem_classes="gr-image-input")
                model_selector = gr.Dropdown(
                    choices=list(models.keys()),
                    label="Model",
                    value="Qwen/Qwen2.5-VL-7B-Instruct",
                    elem_classes="gr-dropdown"
                )

                # NER controls
                with gr.Row():
                    ner_checkbox = gr.Checkbox(label="Run Named Entity Recognition", value=False)
                    ner_labels = gr.Textbox(
                        label="NER Labels (comma-separated)",
                        value=DEFAULT_NER_LABELS,
                        visible=False
                    )

                submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
            with gr.Column(elem_classes="output-container"):
                output_text = gr.HighlightedText(label="Output Text", elem_id="output")

        # Show/hide NER labels based on checkbox
        ner_checkbox.change(
            lambda x: gr.update(visible=x),
            inputs=[ner_checkbox],
            outputs=[ner_labels]
        )

        # Submit button runs OCR (+ optional NER) and updates the stored state
        submit_btn.click(
            run_example,
            inputs=[input_img, model_selector, ner_checkbox, ner_labels],
            outputs=[output_text, ocr_state]
        )

        with gr.Row():
            filename = gr.Textbox(label="Save filename (without extension)", placeholder="Enter filename to save")
            download_btn = gr.Button("Download Image & Text", elem_classes="submit-btn")
            download_output = gr.File(label="Download")

        # Bundle the image, plain text, and JSON metadata from state into a zip
        def create_zip(image, fname, ocr_result):
            # Validate inputs
            if not fname or image is None:
                return None

            try:
                # Convert numpy array to PIL Image if needed
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                elif not isinstance(image, Image.Image):
                    return None

                with tempfile.TemporaryDirectory() as temp_dir:
                    # Save image
                    img_path = os.path.join(temp_dir, f"{fname}.png")
                    image.save(img_path)

                    # Use the OCR result from state
                    original_text = ocr_result.original_text if ocr_result else ""
                    entities = ocr_result.entities if ocr_result else []

                    # Save text
                    txt_path = os.path.join(temp_dir, f"{fname}.txt")
                    with open(txt_path, 'w', encoding='utf-8') as f:
                        f.write(original_text)

                    # Create JSON with text and entities
                    json_data = {
                        "text": original_text,
                        "entities": entities,
                        "image_file": f"{fname}.png"
                    }

                    # Save JSON
                    json_path = os.path.join(temp_dir, f"{fname}.json")
                    with open(json_path, 'w', encoding='utf-8') as f:
                        json.dump(json_data, f, indent=2, ensure_ascii=False)

                    # Create zip file outside the temporary directory
                    output_dir = "downloads"
                    os.makedirs(output_dir, exist_ok=True)
                    zip_path = os.path.join(output_dir, f"{fname}.zip")

                    with zipfile.ZipFile(zip_path, 'w') as zipf:
                        zipf.write(img_path, os.path.basename(img_path))
                        zipf.write(txt_path, os.path.basename(txt_path))
                        zipf.write(json_path, os.path.basename(json_path))

                    return zip_path
            except Exception as e:
                print(f"Error creating zip: {str(e)}")
                return None

        # Download button reads the OCR result back out of state
        download_btn.click(
            create_zip,
            inputs=[input_img, filename, ocr_state],
            outputs=[download_output]
        )

demo.queue(api_open=False)
demo.launch(debug=True)