update app [Bounding Boxes]

#6
Files changed (1) hide show
  1. app.py +92 -57
app.py CHANGED
@@ -15,7 +15,7 @@ import tempfile
15
  import gradio as gr
16
  import requests
17
  import torch
18
- from PIL import Image
19
  import fitz
20
  import numpy as np
21
 
@@ -130,7 +130,7 @@ def generate_and_preview_pdf(image: Image.Image, text_content: str, font_size: i
130
  def process_document_stream(
131
  image: Image.Image,
132
  prompt_input: str,
133
- image_scale_factor: float, # New parameter for image scaling
134
  max_new_tokens: int,
135
  temperature: float,
136
  top_p: float,
@@ -138,7 +138,7 @@ def process_document_stream(
138
  repetition_penalty: float
139
  ):
140
  """
141
- Main function that handles model inference using tencent/POINTS-Reader.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
@@ -147,42 +147,29 @@ def process_document_stream(
147
  yield "Please enter a prompt.", ""
148
  return
149
 
150
- # --- IMPLEMENTATION: Image Scaling based on user input ---
151
  if image_scale_factor > 1.0:
152
  try:
153
  original_width, original_height = image.size
154
  new_width = int(original_width * image_scale_factor)
155
  new_height = int(original_height * image_scale_factor)
156
  print(f"Scaling image from {image.size} to ({new_width}, {new_height}) with factor {image_scale_factor}.")
157
- # Use a high-quality resampling filter for better results
158
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
159
  except Exception as e:
160
  print(f"Error during image scaling: {e}")
161
- # Continue with the original image if scaling fails
162
  pass
163
- # --- END IMPLEMENTATION ---
164
 
165
  temp_image_path = None
166
  try:
167
- # --- FIX: Save the PIL Image to a temporary file ---
168
- # The model expects a file path, not a PIL object.
169
  temp_dir = tempfile.gettempdir()
170
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
171
  image.save(temp_image_path)
172
 
173
- # Prepare content for the model using the temporary file path
174
  content = [
175
  dict(type='image', image=temp_image_path),
176
  dict(type='text', text=prompt_input)
177
  ]
178
- messages = [
179
- {
180
- 'role': 'user',
181
- 'content': content
182
- }
183
- ]
184
 
185
- # Prepare generation configuration from UI inputs
186
  generation_config = {
187
  'max_new_tokens': max_new_tokens,
188
  'repetition_penalty': repetition_penalty,
@@ -192,21 +179,78 @@ def process_document_stream(
192
  'do_sample': True if temperature > 0 else False
193
  }
194
 
195
- # Run inference
196
- response = model.chat(
197
- messages,
198
- tokenizer,
199
- image_processor,
200
- generation_config
201
- )
202
- # Yield the full response at once
203
  yield response, response
204
 
205
  except Exception as e:
206
  traceback.print_exc()
207
  yield f"An error occurred during processing: {str(e)}", ""
208
  finally:
209
- # --- Clean up the temporary image file ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if temp_image_path and os.path.exists(temp_image_path):
211
  os.remove(temp_image_path)
212
 
@@ -233,29 +277,12 @@ def create_gradio_interface():
233
  with gr.Row():
234
  # Left Column (Inputs)
235
  with gr.Column(scale=1):
236
- gr.Textbox(
237
- label="Model in Use ",
238
- value="tencent/POINTS-Reader",
239
- interactive=False
240
- )
241
- prompt_input = gr.Textbox(
242
- label="Query Input",
243
- placeholder="✦︎ Enter the prompt",
244
- value="Perform OCR on the image precisely.",
245
- )
246
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
247
 
248
  with gr.Accordion("Advanced Settings", open=False):
249
- # --- NEW UI ELEMENT: Image Scaling Slider ---
250
- image_scale_factor = gr.Slider(
251
- minimum=1.0,
252
- maximum=3.0,
253
- value=1.0,
254
- step=0.1,
255
- label="Image Upscale Factor",
256
- info="Increases image size before processing. Can improve OCR on small text. Default: 1.0 (no change)."
257
- )
258
- # --- END NEW UI ELEMENT ---
259
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
260
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
261
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
@@ -277,19 +304,21 @@ def create_gradio_interface():
277
  with gr.Tab("📝 Extracted Content"):
278
  raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=15, show_copy_button=True)
279
  with gr.Row():
280
- examples = gr.Examples(
281
- examples=["examples/1.jpeg",
282
- "examples/2.jpeg",
283
- "examples/3.jpeg",
284
- "examples/4.jpeg",
285
- "examples/5.jpeg"],
286
- inputs=image_input, label="Examples"
287
- )
288
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
289
 
290
  with gr.Tab("📰 README.md"):
291
  with gr.Accordion("(Result.md)", open=True):
292
  markdown_output = gr.Markdown()
 
 
 
 
 
 
 
 
 
293
 
294
  with gr.Tab("📋 PDF Preview"):
295
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
@@ -298,15 +327,21 @@ def create_gradio_interface():
298
 
299
  # Event Handlers
300
  def clear_all_outputs():
301
- return None, "", "Raw output will appear here.", "", None, None
 
302
 
303
  process_btn.click(
304
  fn=process_document_stream,
305
- # --- UPDATE: Add the new slider to the inputs list ---
306
  inputs=[image_input, prompt_input, image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
307
  outputs=[raw_output_stream, markdown_output]
308
  )
309
 
 
 
 
 
 
 
310
  generate_pdf_btn.click(
311
  fn=generate_and_preview_pdf,
312
  inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],
@@ -315,7 +350,7 @@ def create_gradio_interface():
315
 
316
  clear_btn.click(
317
  clear_all_outputs,
318
- outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery]
319
  )
320
  return demo
321
 
 
15
  import gradio as gr
16
  import requests
17
  import torch
18
+ from PIL import Image, ImageDraw
19
  import fitz
20
  import numpy as np
21
 
 
130
  def process_document_stream(
131
  image: Image.Image,
132
  prompt_input: str,
133
+ image_scale_factor: float,
134
  max_new_tokens: int,
135
  temperature: float,
136
  top_p: float,
 
138
  repetition_penalty: float
139
  ):
140
  """
141
+ Main function that handles model inference for general OCR.
142
  """
143
  if image is None:
144
  yield "Please upload an image.", ""
 
147
  yield "Please enter a prompt.", ""
148
  return
149
 
 
150
  if image_scale_factor > 1.0:
151
  try:
152
  original_width, original_height = image.size
153
  new_width = int(original_width * image_scale_factor)
154
  new_height = int(original_height * image_scale_factor)
155
  print(f"Scaling image from {image.size} to ({new_width}, {new_height}) with factor {image_scale_factor}.")
 
156
  image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
157
  except Exception as e:
158
  print(f"Error during image scaling: {e}")
 
159
  pass
 
160
 
161
  temp_image_path = None
162
  try:
 
 
163
  temp_dir = tempfile.gettempdir()
164
  temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
165
  image.save(temp_image_path)
166
 
 
167
  content = [
168
  dict(type='image', image=temp_image_path),
169
  dict(type='text', text=prompt_input)
170
  ]
171
+ messages = [{'role': 'user', 'content': content}]
 
 
 
 
 
172
 
 
173
  generation_config = {
174
  'max_new_tokens': max_new_tokens,
175
  'repetition_penalty': repetition_penalty,
 
179
  'do_sample': True if temperature > 0 else False
180
  }
181
 
182
+ response = model.chat(messages, tokenizer, image_processor, generation_config)
 
 
 
 
 
 
 
183
  yield response, response
184
 
185
  except Exception as e:
186
  traceback.print_exc()
187
  yield f"An error occurred during processing: {str(e)}", ""
188
  finally:
189
+ if temp_image_path and os.path.exists(temp_image_path):
190
+ os.remove(temp_image_path)
191
+
192
+ # --- Bounding Box Extraction Logic ---
193
# Matches one "<box>x1, y1, x2, y2</box>" tag; each coordinate is captured separately.
_BOX_TAG_PATTERN = r"<box>(\d+),\s*(\d+),\s*(\d+),\s*(\d+)</box>"


def _pair_text_with_boxes(response: str) -> list:
    """Return ``(text, (x0, y0, x1, y1))`` pairs parsed from *response*.

    Each ``<box>`` tag is paired with the text that sits between the previous
    tag (or the start of the string) and the tag itself. Walking the matches
    in a single pass keeps text and coordinates aligned even when a tag has no
    text in front of it; splitting and finding separately would shift every
    following pair in that case. Text after the last tag has no box and is
    dropped.
    """
    pairs = []
    cursor = 0
    for match in re.finditer(_BOX_TAG_PATTERN, response):
        # Collapse internal whitespace/newlines so each record stays on one line.
        text = " ".join(response[cursor:match.start()].split())
        pairs.append((text, tuple(int(group) for group in match.groups())))
        cursor = match.end()
    return pairs


@spaces.GPU
def extract_text_with_coordinates(image: Image.Image):
    """
    Run the model with a fixed OCR-with-coordinates prompt and visualize the result.

    The image is written to a temporary PNG (the model is given a file path),
    the response is parsed for ``text<box>x1, y1, x2, y2</box>`` items, and each
    box is drawn in red on a copy of the input image.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.

    Returns:
        A ``(text, image)`` tuple. ``text`` holds one record per line in the
        form ``x1,y1,x2,y1,x2,y2,x1,y2,text`` (the box as a quadrilateral,
        followed by the recognized text); ``image`` is the annotated copy.
        If processing fails, ``text`` is an error message and ``image`` is None.

    Raises:
        gr.Error: If *image* is None.
    """
    if image is None:
        raise gr.Error("Please upload an image first in the main tab.")

    prompt = "Please perform OCR on the image and provide the bounding box for each recognized text line. The format should be 'text<box>x1, y1, x2, y2</box>'."
    temp_image_path = None
    try:
        temp_dir = tempfile.gettempdir()
        temp_image_path = os.path.join(temp_dir, f"temp_image_{uuid.uuid4()}.png")
        image.save(temp_image_path)

        content = [dict(type='image', image=temp_image_path), dict(type='text', text=prompt)]
        messages = [{'role': 'user', 'content': content}]
        generation_config = {'max_new_tokens': 4096}

        response = model.chat(messages, tokenizer, image_processor, generation_config)

        original_width, original_height = image.size
        vis_image = image.copy()
        draw = ImageDraw.Draw(vis_image)
        records = []

        for line_text, (x0, y0, x1, y1) in _pair_text_with_boxes(response):
            # Scale from the 0-1000 coordinate basis to pixel coordinates.
            # NOTE(review): the 1000-unit basis is taken from the original
            # code's comment - confirm against the model's documentation.
            xa = int(x0 * original_width / 1000)
            ya = int(y0 * original_height / 1000)
            xb = int(x1 * original_width / 1000)
            yb = int(y1 * original_height / 1000)

            # The model may emit corners in the wrong order; ImageDraw.rectangle
            # rejects x1 < x0 / y1 < y0 with ValueError, which would otherwise
            # discard every box for the image.
            left, right = sorted((xa, xb))
            top, bottom = sorted((ya, yb))

            draw.rectangle([left, top, right, bottom], outline="red", width=2)

            # Format output as a polygon (quadrilateral) and the extracted text
            records.append(
                f"{left},{top},{right},{top},{right},{bottom},{left},{bottom},{line_text}"
            )

        return "\n".join(records).strip(), vis_image

    except Exception as e:
        # Boundary handler for the UI callback: log the traceback and surface
        # the message in the text output instead of crashing the request.
        traceback.print_exc()
        return f"An error occurred: {str(e)}", None
    finally:
        # Always remove the temporary image, whether or not inference succeeded.
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
256
 
 
277
  with gr.Row():
278
  # Left Column (Inputs)
279
  with gr.Column(scale=1):
280
+ gr.Textbox(label="Model in Use ⚡", value="tencent/POINTS-Reader", interactive=False)
281
+ prompt_input = gr.Textbox(label="Query Input", placeholder="✦︎ Enter the prompt", value="Perform OCR on the image precisely.")
 
 
 
 
 
 
 
 
282
  image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
283
 
284
  with gr.Accordion("Advanced Settings", open=False):
285
+ image_scale_factor = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Image Upscale Factor", info="Increases image size before processing. Can improve OCR on small text.")
 
 
 
 
 
 
 
 
 
286
  max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=2048, step=256, label="Max New Tokens")
287
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.05, value=0.7)
288
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.8)
 
304
  with gr.Tab("📝 Extracted Content"):
305
  raw_output_stream = gr.Textbox(label="Raw Model Output (max T ≤ 120s)", interactive=False, lines=15, show_copy_button=True)
306
  with gr.Row():
307
+ examples = gr.Examples(examples=["examples/1.jpeg", "examples/2.jpeg", "examples/3.jpeg", "examples/4.jpeg", "examples/5.jpeg"], inputs=image_input, label="Examples")
 
 
 
 
 
 
 
308
  gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/POINTS-Reader-OCR/discussions) | [prithivMLmods🤗](https://huggingface.co/prithivMLmods)")
309
 
310
  with gr.Tab("📰 README.md"):
311
  with gr.Accordion("(Result.md)", open=True):
312
  markdown_output = gr.Markdown()
313
+
314
+ with gr.Tab("Bounding Boxes"):
315
+ gr.Markdown("Click the button to extract text and visualize its location on the image. This uses a specialized prompt to get coordinates from the model.")
316
+ with gr.Row():
317
+ with gr.Column(scale=1):
318
+ ocr_button = gr.Button("🔍 Extract Text with Coordinates", variant="primary")
319
+ ocr_text = gr.Textbox(label="Extracted Text with Coordinates", info="Format: x1,y1,x2,y2,x3,y3,x4,y4,text", lines=15, show_copy_button=True)
320
+ with gr.Column(scale=1):
321
+ ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
322
 
323
  with gr.Tab("📋 PDF Preview"):
324
  generate_pdf_btn = gr.Button("📄 Generate PDF & Render", variant="primary")
 
327
 
328
  # Event Handlers
329
  def clear_all_outputs():
330
+ # Clear all input and output fields across all tabs
331
+ return None, "", "Raw output will appear here.", "", None, None, "", None
332
 
333
  process_btn.click(
334
  fn=process_document_stream,
 
335
  inputs=[image_input, prompt_input, image_scale_factor, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
336
  outputs=[raw_output_stream, markdown_output]
337
  )
338
 
339
+ ocr_button.click(
340
+ fn=extract_text_with_coordinates,
341
+ inputs=[image_input],
342
+ outputs=[ocr_text, ocr_vis]
343
+ )
344
+
345
  generate_pdf_btn.click(
346
  fn=generate_and_preview_pdf,
347
  inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],
 
350
 
351
  clear_btn.click(
352
  clear_all_outputs,
353
+ outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery, ocr_text, ocr_vis]
354
  )
355
  return demo
356