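"""Gradio Space pairing a Byaldi (ColPali) retriever with Qwen2-VL:
an uploaded image is indexed, its text is extracted with Qwen2-VL,
and the result can then be keyword-searched with HTML highlighting.
"""
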
import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
import re

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Byaldi retriever and the Qwen2-VL vision-language model
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")  # ColPali checkpoint loaded through Byaldi
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to the selected device

# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

# Function for OCR and text extraction
@spaces.GPU(duration=120)  # Request GPU hardware for up to 120 seconds per call
def ocr_and_extract(image):
    try:
        # Save the uploaded image temporarily (convert to RGB so it can be written as JPEG)
        temp_image_path = "temp_image.jpg"
        image.convert("RGB").save(temp_image_path)

        # Index the image with Byaldi, and force overwrite of the existing index
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",  # Reuse the same index
            store_collection_with_index=False,
            overwrite=True  # Overwrite the index for every new image
        )

        # Run a retrieval query against the indexed image (the result is not used further here)
        results = rag_model.search("", k=1)

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)

        # Ask Qwen2-VL to transcribe the text in the image
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": "Extract all readable text from this image."},
                ],
            }
        ]

        # Process the message and prepare for Qwen2-VL
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        # Move the image inputs and processor outputs to CUDA
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)

        # Trim the prompt tokens so only the newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Safety net: drop any chat-template labels that leak into the decoded text
        filtered_output = [
            line for line in output_text[0].split("\n")
            if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])
        ]
        extracted_text = "\n".join(filtered_output).strip()

        # Clean up the temporary file
        os.remove(temp_image_path)

        return extracted_text

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

def search_keywords(extracted_text, keywords):
    if not extracted_text:
        return "No text extracted yet. Please upload an image."

    # Highlight matching keywords in the extracted text
    highlighted_text = extracted_text
    for keyword in keywords.split():
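        # Wrap every case-insensitive match in <mark> so it renders highlighted,
        # e.g. searching "total" turns "Total due" into "<mark>Total</mark> due"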
        highlighted_text = re.sub(f"({re.escape(keyword)})", r"<mark>\1</mark>", highlighted_text, flags=re.IGNORECASE)

    # Return the highlighted text as HTML
    return f"<div style='white-space: pre-wrap'>{highlighted_text}</div>"

# Gradio interface for image input and keyword search
with gr.Blocks() as iface:
    # Add a title at the top of the interface
    gr.HTML("<h1 style='text-align: center'>Byaldi + Qwen2VL</h1>")

    # Image upload and text extraction section
    with gr.Column():
        img_input = gr.Image(type="pil", label="Upload an Image")
        extracted_output = gr.Textbox(label="Extracted Text", interactive=False)  # Use Textbox to store text

        # Functionality to trigger the OCR and extraction
        img_button = gr.Button("Extract Text")
        img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)

    # Keyword search section
    with gr.Column():
        search_input = gr.Textbox(label="Enter keywords to search")
        search_output = gr.HTML(label="Search Results")

        # Functionality to search within the extracted text
        search_button = gr.Button("Search")
        search_button.click(fn=search_keywords, inputs=[extracted_output, search_input], outputs=search_output)

iface.launch()
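
# Note: when running outside a Hugging Face Space, iface.launch(share=True) can be
# used to expose a temporary public Gradio link.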