import gradio as gr from byaldi import RAGMultiModalModel from transformers import Qwen2VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import torch from PIL import Image import os import traceback import spaces import re # Check if CUDA is available device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load the Byaldi and Qwen2-VL models rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali") # Byaldi model qwen_model = Qwen2VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16 ).to(device) # Move Qwen2-VL to GPU # Processor for Qwen2-VL processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True) # Function for OCR and text extraction @spaces.GPU(duration=120) # Increased GPU duration to 120 seconds def ocr_and_extract(image): try: # Save the uploaded image temporarily temp_image_path = "temp_image.jpg" image.save(temp_image_path) # Index the image with Byaldi, and force overwrite of the existing index rag_model.index( input_path=temp_image_path, index_name="image_index", # Reuse the same index store_collection_with_index=False, overwrite=True # Overwrite the index for every new image ) # Perform the search query on the indexed image results = rag_model.search("", k=1) # Prepare the input for Qwen2-VL image_data = Image.open(temp_image_path) messages = [ { "role": "user", "content": [ {"type": "image", "image": image_data}, ], } ] # Process the message and prepare for Qwen2-VL text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, _ = process_vision_info(messages) # Move the image inputs and processor outputs to CUDA inputs = processor( text=[text_input], images=image_inputs, padding=True, return_tensors="pt", ).to(device) # Generate the output with Qwen2-VL generated_ids = qwen_model.generate(**inputs, max_new_tokens=50) output_text = processor.batch_decode( generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) # Filter out "You are a helpful assistant" and "assistant" labels filtered_output = [line for line in output_text[0].split("\n") if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])] extracted_text = "\n".join(filtered_output).strip() # Clean up the temporary file os.remove(temp_image_path) return extracted_text except Exception as e: error_message = str(e) traceback.print_exc() return f"Error: {error_message}" def search_keywords(extracted_text, keywords): if not extracted_text: return "No text extracted yet. Please upload an image." # Highlight matching keywords in the extracted text highlighted_text = extracted_text for keyword in keywords.split(): highlighted_text = re.sub(f"({re.escape(keyword)})", r"\1", highlighted_text, flags=re.IGNORECASE) # Return the highlighted text as HTML return f"