prithivMLmods committed
Commit c307af6 · verified · 1 Parent(s): ce03905

Update app.py

Files changed (1):
  1. app.py +49 -102
app.py CHANGED
@@ -14,9 +14,7 @@ from transformers import (
 )
 from transformers import Qwen2_5_VLForConditionalGeneration
 
-# ---------------------------
 # Helper Functions
-# ---------------------------
 def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
     """
     Returns an HTML snippet for a thin animated progress bar with a label.
@@ -49,7 +47,6 @@ def downsample_video(video_path):
     if total_frames <= 0 or fps <= 0:
         vidcap.release()
         return frames
-    # Determine 10 evenly spaced frame indices.
     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
@@ -63,8 +60,7 @@ def downsample_video(video_path):
     return frames
 
 # Model and Processor Setup
-# Qwen2VL OCR (default branch)
-QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QV_MODEL_ID,
@@ -72,7 +68,6 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# RolmOCR branch (@RolmOCR)
 ROLMOCR_MODEL_ID = "reducto/RolmOCR"
 rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
 rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -83,111 +78,62 @@ rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
 # Main Inference Function
 @spaces.GPU
-def model_inference(input_dict, history):
+def model_inference(input_dict, history, use_rolmocr=False):
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
-    # RolmOCR Inference (@RolmOCR)
-    if text.lower().startswith("@rolmocr"):
-        # Remove the tag from the query.
-        text_prompt = text[len("@rolmocr"):].strip()
-        # Check if a video is provided for inference.
-        if files and isinstance(files[0], str) and files[0].lower().endswith((".mp4", ".avi", ".mov")):
-            video_path = files[0]
-            frames = downsample_video(video_path)
+    if not text and not files:
+        yield "Error: Please input a text query or provide files (images or videos)."
+        return
+
+    # Process files: images and videos
+    image_list = []
+    for idx, file in enumerate(files):
+        if file.lower().endswith((".mp4", ".avi", ".mov")):
+            frames = downsample_video(file)
             if not frames:
                 yield "Error: Could not extract frames from the video."
                 return
-            # Build the message: prompt followed by each frame with its timestamp.
-            content_list = [{"type": "text", "text": text_prompt}]
-            for image, timestamp in frames:
-                content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
-                content_list.append({"type": "image", "image": image})
-            messages = [{"role": "user", "content": content_list}]
-            # For video, extract images only.
-            video_images = [image for image, _ in frames]
-            prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            inputs = rolmocr_processor(
-                text=[prompt_full],
-                images=video_images,
-                return_tensors="pt",
-                padding=True,
-            ).to("cuda")
+            for frame, timestamp in frames:
+                label = f"Video {idx+1} Frame {timestamp}:"
+                image_list.append((label, frame))
         else:
-            # Assume image(s) or text query.
-            if len(files) > 1:
-                images = [load_image(image) for image in files]
-            elif len(files) == 1:
-                images = [load_image(files[0])]
-            else:
-                images = []
-            if text_prompt == "" and not images:
-                yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
+            try:
+                img = load_image(file)
+                label = f"Image {idx+1}:"
+                image_list.append((label, img))
+            except Exception as e:
+                yield f"Error loading image: {str(e)}"
                 return
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": text_prompt},
-                ],
-            }]
-            prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            inputs = rolmocr_processor(
-                text=[prompt_full],
-                images=images if images else None,
-                return_tensors="pt",
-                padding=True,
-            ).to("cuda")
-        streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-        thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        # Use a different color scheme for RolmOCR (purple-themed).
-        yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
 
-    # Default Inference: Qwen2VL OCR
-    # Process files: support multiple images.
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
-
-    if text == "" and not images:
-        yield "Error: Please input a text query and optionally image(s)."
-        return
-    if text == "" and images:
-        yield "Error: Please input a text query along with the image(s)."
-        return
+    # Build content list
+    content = [{"type": "text", "text": text}]
+    for label, img in image_list:
+        content.append({"type": "text", "text": label})
+        content.append({"type": "image", "image": img})
+
+    messages = [{"role": "user", "content": content}]
 
-    messages = [{
-        "role": "user",
-        "content": [
-            *[{"type": "image", "image": image} for image in images],
-            {"type": "text", "text": text},
-        ],
-    }]
-    prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = qwen_processor(
+    # Select processor and model
+    processor = rolmocr_processor if use_rolmocr else qwen_processor
+    model = rolmocr_model if use_rolmocr else qwen_model
+    model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"
+
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    all_images = [item["image"] for item in content if item["type"] == "image"]
+    inputs = processor(
         text=[prompt_full],
-        images=images if images else None,
+        images=all_images if all_images else None,
         return_tensors="pt",
         padding=True,
     ).to("cuda")
-    streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
+
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
-    yield progress_bar_html("Processing with Qwen2VL OCR")
+    yield progress_bar_html(f"Processing with {model_name}")
     for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
@@ -196,25 +142,26 @@ def model_inference(input_dict, history):
 
 # Gradio Interface
 examples = [
-    [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
-    [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
-    [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
+    [{"text": "OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
+    [{"text": "Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
+    [{"text": "OCR the Image", "files": ["rolm/3.jpeg"]}],
     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`**",
+    description="# **Multimodal OCR with Model Selection**",
     examples=examples,
     textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", "video"],
-        file_count="multiple",
-        placeholder="Use tag @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
+        label="Query Input",
+        file_types=["image", "video"],
+        file_count="multiple",
+        placeholder="Input your query and optionally upload image(s) or video(s). Select the model using the checkbox."
     ),
     stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
+    additional_inputs=[gr.Checkbox(label="Use RolmOCR", value=False, info="Check to use RolmOCR, uncheck to use Qwen2VL OCR")],
 )
 
 demo.launch(debug=True)
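
The unchanged downsample_video helper (partially visible in the hunks above) picks ten evenly spaced frame indices with np.linspace and seeks to each with OpenCV. A minimal self-contained sketch of that sampling idea, assuming opencv-python and numpy are installed; the function and parameter names here are illustrative, not the app's:

import cv2
import numpy as np

def sample_frames(video_path, num_frames=10):
    # Illustrative re-implementation of the even-sampling idea.
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    # Evenly spaced indices across the whole clip, as in the app.
    for i in np.linspace(0, total_frames - 1, num_frames, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, frame = vidcap.read()
        if success:
            timestamp = round(float(i) / fps, 2)  # seconds into the video
            frames.append((frame, timestamp))
    vidcap.release()
    return frames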
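
Both the old and new code stream tokens the same way: generate() blocks until completion, so it runs on a background thread while the caller iterates a TextIteratorStreamer; the refactor only changes which model/processor pair feeds it. A generic sketch of that pattern, using a small stand-in model (gpt2 is chosen purely for illustration):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # stand-in model, not one of the app's OCR models
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("OCR systems convert images of text into", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Run generation on a worker thread; the main thread consumes decoded
# text chunks from the streamer as they are produced.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32))
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer)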
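
The core of this commit is replacing the @rolmocr text tag with a checkbox: components listed in ChatInterface's additional_inputs are passed to fn as extra arguments after (message, history), which is how use_rolmocr reaches model_inference. A minimal sketch of that wiring; respond and use_alt_model are hypothetical names:

import gradio as gr

def respond(message, history, use_alt_model=False):
    # The checkbox state arrives as an extra argument on every submit.
    chosen = "alternate" if use_alt_model else "default"
    return f"[{chosen} model] you said: {message}"

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[gr.Checkbox(label="Use alternate model", value=False)],
)

if __name__ == "__main__":
    demo.launch()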