DeF0017 committed on
Commit
b49897a
1 Parent(s): dab1ed0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
4
+ from qwen_vl_utils import process_vision_info
5
+ import torch
6
+ from PIL import Image
7
+ import subprocess
8
+ import numpy as np
9
+ import os
10
+ from threading import Thread
11
+ import uuid
12
+ import io
13
+ import re # Import regular expressions for word highlighting
14
+
15
# Load the checkpoint and its processor a single time, when the module is imported.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
)
# Move the half-precision weights to the GPU and switch to inference mode.
model = model.to("cuda").eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Markdown rendered at the top of the demo page.
DESCRIPTION = "[Qwen2-VL-2B Demo](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)"

# File suffixes used to classify uploads as image or video.
# image_extensions comes from Pillow's registry; the video entries are bare
# suffixes written without a leading dot.
image_extensions = Image.registered_extensions()
video_extensions = (
    "avi",
    "mp4",
    "mov",
    "mkv",
    "flv",
    "wmv",
    "mjpeg",
    "wav",
    "gif",
    "webm",
    "m4v",
    "3gp",
)
29
+
30
+
31
def identify_and_save_blob(blob_path):
    """Copy an upload to a new file whose suffix reflects its detected media type.

    The file at ``blob_path`` is read fully into memory. If Pillow can verify
    the bytes as an image, the copy is given a ``.png`` suffix and the type
    ``"image"``; otherwise the content is assumed to be a video and the copy
    gets a ``.mp4`` suffix and the type ``"video"``. The bytes are written
    unchanged, so the suffix is only a label, not a format conversion.

    Args:
        blob_path: Path of the file to inspect.

    Returns:
        A ``(filename, media_type)`` tuple. ``filename`` is a relative path
        (``temp_<uuid>_media<ext>``), so the copy lands in the current working
        directory; ``media_type`` is ``"image"`` or ``"video"``.

    Raises:
        ValueError: If ``blob_path`` does not exist, or if any other error
            occurs while reading or writing. The original exception is
            attached as ``__cause__``.
    """
    try:
        with open(blob_path, "rb") as file:
            blob_content = file.read()

        # Try to identify the content as an image first.
        try:
            # The context manager releases the image handle after verification.
            with Image.open(io.BytesIO(blob_content)) as candidate:
                candidate.verify()
            extension = ".png"  # Default to PNG for saving
            media_type = "image"
        except (IOError, SyntaxError):
            # Not a valid image: fall back to treating it as a video.
            extension = ".mp4"  # Default to MP4 for saving
            media_type = "video"

        # A UUID in the name keeps concurrent uploads from colliding.
        filename = f"temp_{uuid.uuid4()}_media{extension}"
        with open(filename, "wb") as f:
            f.write(blob_content)

        return filename, media_type

    except FileNotFoundError as e:
        raise ValueError(f"The file {blob_path} was not found.") from e
    except Exception as e:
        raise ValueError(f"An error occurred while processing the file: {e}") from e
58
+
59
+
60
def _render_highlighted(text, pattern):
    """Return ``text`` HTML-escaped, with each match of ``pattern`` wrapped in <mark>."""
    import html

    if pattern is None:
        return html.escape(text)

    pieces = []
    cursor = 0
    # Match against the raw text and escape each segment separately, so the
    # search word can never match inside an entity such as "&amp;".
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        pieces.append(f"<mark>{html.escape(match.group(0))}</mark>")
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


@spaces.GPU
def qwen_inference(media_input, search_word):
    """
    Performs OCR on the input media and highlights the search_word in the extracted text.

    The media type is chosen from the file suffix (case-insensitively); files
    with an unrecognised suffix are sniffed by ``identify_and_save_blob``.
    Generation runs in a background thread and is streamed back chunk by chunk.

    Args:
        media_input (str): Path to the uploaded image or video file.
        search_word (str): The word to search and highlight in the OCR result.
            Matching is case-insensitive and literal (no regex syntax). An
            empty value disables highlighting.

    Yields:
        str: The text generated so far as an HTML fragment. The model output is
        HTML-escaped and every occurrence of ``search_word`` is wrapped in
        ``<mark>`` tags.

    Raises:
        ValueError: If no file path was supplied or the media type cannot be
            determined.
    """
    text_input = "Extract text"  # Hardcoded text query

    # Without this guard a missing upload would leave media_path and
    # media_type unbound and crash further down.
    if not isinstance(media_input, str) or not media_input:
        raise ValueError(
            "Unsupported media type. Please upload an image or video."
        )

    media_path = media_input
    lowered_path = media_path.lower()
    image_suffixes = tuple(image_extensions)
    # Normalise to dotted suffixes so "foomp4" is not mistaken for a video.
    video_suffixes = tuple("." + ext.lstrip(".") for ext in video_extensions)

    if lowered_path.endswith(image_suffixes):
        media_type = "image"
    elif lowered_path.endswith(video_suffixes):
        media_type = "video"
    else:
        try:
            media_path, media_type = identify_and_save_blob(media_input)
            print(media_path, media_type)
        except Exception as e:
            print(e)
            raise ValueError(
                "Unsupported media type. Please upload an image or video."
            ) from e

    print(f"Processing media: {media_path} (Type: {media_type})")

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": media_type,
                    media_type: media_path,
                    **({"fps": 8.0} if media_type == "video" else {}),
                },
                {"type": "text", "text": text_input},
            ],
        }
    ]

    # Apply chat template to format the input for the model
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    # Prepare model inputs
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # Initialize the streamer for iterative generation
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Run generation in a separate thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Compile once; the pattern does not change between chunks.
    pattern = (
        re.compile(re.escape(search_word), re.IGNORECASE) if search_word else None
    )

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Re-render the whole buffer each time: a match may span two chunks.
        yield _render_highlighted(buffer, pattern)

    # The streamer is exhausted, so generation has finished; reap the thread.
    thread.join()
142
+
143
+
144
# Styling for the result pane. The selector matches the elem_id given to the
# HTML output component below. pre-wrap keeps the line breaks of the
# extracted text, which HTML rendering would otherwise collapse.
css = """
  #output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
    white-space: pre-wrap;
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Image/Video Input"):
        with gr.Row():
            with gr.Column():
                input_media = gr.File(
                    label="Upload Image or Video", type="filepath"
                )
                search_word = gr.Textbox(
                    label="Search Word", placeholder="Enter word to highlight", lines=1
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # Use HTML component to display highlighted text; elem_id ties
                # it to the "#output" CSS rule above.
                output_text = gr.HTML(label="Output Text", elem_id="output")

        # qwen_inference is a generator, so the output updates as text streams in.
        submit_btn.click(
            qwen_inference,
            inputs=[input_media, search_word],
            outputs=[output_text]
        )

demo.launch(debug=True)