prithivMLmods committed on
Commit 647e0d3 · verified · 1 Parent(s): cd499b2

Update app.py

Files changed (1)
  1. app.py +164 -45
app.py CHANGED
@@ -1,15 +1,18 @@
 import os
 import sys
-import spaces
-from typing import Iterable
-import gradio as gr
-import torch
+import io
+import json
 import requests
+from typing import Iterable, List, Tuple, Dict, Any
+
 from PIL import Image, ImageDraw, ImageFont
+import gradio as gr
+import torch
 from transformers import AutoProcessor, Florence2ForConditionalGeneration
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
+# ---------- Theme (unchanged) ----------
 colors.steel_blue = colors.Color(
     name="steel_blue",
     c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
@@ -69,6 +72,7 @@ css = """
 }
 """
 
+# ---------- Models ----------
 MODEL_IDS = {
     "Florence-2-base": "florence-community/Florence-2-base",
     "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
@@ -76,15 +80,15 @@ MODEL_IDS = {
     "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
 }
 
-models = {}
-processors = {}
+models: Dict[str, Florence2ForConditionalGeneration] = {}
+processors: Dict[str, AutoProcessor] = {}
 
 print("Loading Florence-2 models... This may take a while.")
 for name, repo_id in MODEL_IDS.items():
     print(f"Loading {name}...")
     model = Florence2ForConditionalGeneration.from_pretrained(
         repo_id,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         device_map="auto",
         trust_remote_code=True
     )
@@ -95,86 +99,201 @@ for name, repo_id in MODEL_IDS.items():
 
 print("\n🎉 All models loaded successfully!")
 
-def draw_bboxes(image, results, task):
+# ---------- Utilities ----------
+def _safe_parse_json_like(text: Any) -> Any:
+    """
+    If `text` is already a dict, return it. If it is a JSON-like string, try json.loads.
+    Otherwise return the original text.
+    """
+    if isinstance(text, dict):
+        return text
+    if isinstance(text, str):
+        text_str = text.strip()
+        # try to decode only if it looks like JSON
+        if (text_str.startswith("{") and text_str.endswith("}")) or (text_str.startswith("[") and text_str.endswith("]")):
+            try:
+                return json.loads(text_str)
+            except Exception:
+                # fall back to returning the original string
+                return text
+    return text
+
+def _find_bboxes_and_labels(obj: Any) -> List[Tuple[List[int], str]]:
     """
-    Draws bounding boxes on the image based on the model's results.
+    Recursively search `obj` (dict/list) for paired 'bboxes' and 'labels' (or region entries).
+    Returns a list of (bbox, label) tuples.
+    Each bbox is assumed to be [x1, y1, x2, y2] (integers/floats).
     """
-    if task not in results:
-        return image
+    found: List[Tuple[List[int], str]] = []
+
+    def recurse(o: Any):
+        if isinstance(o, dict):
+            # direct 'bboxes'/'labels' pair
+            if "bboxes" in o:
+                bboxes = o.get("bboxes", [])
+                labels = o.get("labels", [])
+                # if the labels list is shorter, fill with empty strings
+                for i, bx in enumerate(bboxes):
+                    lbl = labels[i] if i < len(labels) else ""
+                    # bboxes sometimes come as dicts with keys, sometimes as lists
+                    if isinstance(bx, dict) and {"x", "y", "w", "h"}.issubset(bx.keys()):
+                        # convert xywh to x1, y1, x2, y2
+                        x = bx["x"]; y = bx["y"]; w = bx["w"]; h = bx["h"]
+                        found.append(([int(x), int(y), int(x + w), int(y + h)], lbl))
+                    else:
+                        # assume list-like [x1,y1,x2,y2] or [x,y,w,h]
+                        try:
+                            bx_list = list(map(int, bx))
+                            if len(bx_list) == 4:
+                                x1, y1, x2, y2 = bx_list
+                                # heuristic: if x2 > x1 and y2 > y1, assume x1,y1,x2,y2; otherwise maybe xywh
+                                if x2 > x1 and y2 > y1:
+                                    found.append(([x1, y1, x2, y2], lbl))
+                                else:
+                                    # treat as xywh
+                                    found.append(([x1, y1, x1 + x2, y1 + y2], lbl))
+                            else:
+                                # skip unexpected formats
+                                pass
+                        except Exception:
+                            pass
+            # also check for region entries like {'bbox': ..., 'text': ...}
+            if "regions" in o and isinstance(o["regions"], list):
+                for reg in o["regions"]:
+                    if isinstance(reg, dict) and "bbox" in reg:
+                        bx = reg["bbox"]
+                        lbl = reg.get("label", reg.get("text", ""))
+                        try:
+                            bx_list = list(map(int, bx))
+                            if len(bx_list) == 4:
+                                found.append(([bx_list[0], bx_list[1], bx_list[2], bx_list[3]], lbl))
+                        except Exception:
+                            pass
+            # recurse deeper
+            for v in o.values():
+                recurse(v)
+        elif isinstance(o, list):
+            for item in o:
+                recurse(item)
+        # primitives are ignored
 
-    predictions = results[task]
-    if not predictions:
-        return image
+    recurse(obj)
+    return found
 
-    draw = ImageDraw.Draw(image)
+def _draw_bboxes_on_image(img: Image.Image, boxes_and_labels: List[Tuple[List[int], str]]) -> Image.Image:
+    """
+    Draw bounding boxes and labels on a copy of `img`.
+    """
+    annotated = img.convert("RGB").copy()
+    draw = ImageDraw.Draw(annotated)
+    # try to get a TTF font (PIL may not have one available)
     try:
-        font = ImageFont.truetype("arial.ttf", 20)
-    except IOError:
+        font = ImageFont.truetype("DejaVuSans.ttf", size=14)
+    except Exception:
         font = ImageFont.load_default()
 
-    for pred in predictions:
-        if 'bboxes' in pred and 'labels' in pred:
-            for box, label in zip(pred['bboxes'], pred['labels']):
-                draw.rectangle(box, outline="red", width=3)
-                text_position = [box[0], box[1] - 25]
-                # Add a background to the text for better visibility
-                text_bbox = draw.textbbox(tuple(text_position), label, font=font)
-                draw.rectangle(text_bbox, fill="red")
-                draw.text(tuple(text_position), label, fill="white", font=font)
-        elif 'bboxes' in pred: # For tasks like REGION_PROPOSAL without labels
-            for box in pred['bboxes']:
-                draw.rectangle(box, outline="red", width=3)
+    for bbox, label in boxes_and_labels:
+        # bbox is [x1, y1, x2, y2]
+        x1, y1, x2, y2 = bbox
+        # clamp coordinates to the image bounds
+        x1 = max(0, int(x1)); y1 = max(0, int(y1))
+        x2 = min(annotated.width - 1, int(x2)); y2 = min(annotated.height - 1, int(y2))
+        # draw a thicker rectangle by drawing several offsets
+        thickness = max(2, int(round(min(annotated.width, annotated.height) / 200)))
+        for t in range(thickness):
+            draw.rectangle([x1 - t, y1 - t, x2 + t, y2 + t], outline="red")
+        # draw the label background
+        if label is None:
+            label = ""
+        label_text = str(label)
+        tb = draw.textbbox((0, 0), label_text, font=font); text_w, text_h = tb[2] - tb[0], tb[3] - tb[1]  # ImageDraw.textsize was removed in Pillow 10
+        # solid background rectangle for the label
+        label_bg = [x1, max(0, y1 - text_h - 4), x1 + text_w + 6, y1]
+        draw.rectangle(label_bg, fill="red")
+        # text
+        draw.text((x1 + 3, max(0, y1 - text_h - 2)), label_text, fill="white", font=font)
 
-    return image
+    return annotated
 
+# ---------- Inference function ----------
+# tasks for which we attempt to extract and display bboxes
+VISUAL_REGION_TASKS = {"<OD>", "<DENSE_REGION_CAPTION>", "<OCR_WITH_REGION>", "<REGION_PROPOSAL>"}
 
-@spaces.GPU
+# If running on GPU Spaces, re-add `import spaces` and the @spaces.GPU decorator here
 def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
                             max_new_tokens: int = 1024, num_beams: int = 3):
     """
     Runs inference using the selected Florence-2 model.
+    Returns a tuple: (parsed_answer, annotated_image_or_none).
     """
     if image is None:
-        return "Please upload an image to get started.", None
+        return {"error": "Please upload an image to get started."}, None
 
     model = models[model_name]
     processor = processors[model_name]
 
-    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
+    # Prepare inputs and move them to the model device
+    inputs = processor(text=task_prompt, images=image, return_tensors="pt")
+    # cast only floating-point tensors to bfloat16; input_ids must stay integer
+    device = model.device
+    for k, v in inputs.items():
+        if isinstance(v, torch.Tensor):
+            inputs[k] = v.to(device, dtype=torch.bfloat16) if v.is_floating_point() else v.to(device)
 
+    # Generate
     generated_ids = model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
+        input_ids=inputs.get("input_ids"),
+        pixel_values=inputs.get("pixel_values"),
         max_new_tokens=max_new_tokens,
         num_beams=num_beams,
         do_sample=False
     )
 
+    # Decode
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
 
+    # Post-process (for region tasks the Florence-2 processor returns a structured
+    # output, e.g. a dict with 'bboxes' and 'labels')
    image_size = image.size
     parsed_answer = processor.post_process_generation(
         generated_text, task=task_prompt, image_size=image_size
     )
 
-    if task_prompt in ["<OD>", "<DENSE_REGION_CAPTION>", "<OCR_WITH_REGION>", "<REGION_PROPOSAL>"]:
-        image_with_boxes = draw_bboxes(image.copy(), parsed_answer, task_prompt)
-        return image_with_boxes, parsed_answer
-    else:
-        return image, parsed_answer
+    # Make parsed_answer JSON-serializable and easy to inspect
+    parsed_serializable = parsed_answer
+    # if it is a string that contains JSON, attempt to parse it
+    if isinstance(parsed_answer, str):
+        parsed_serializable = _safe_parse_json_like(parsed_answer)
+
+    annotated_image = None
+    # for visual region tasks, try to find bboxes and labels
+    if task_prompt in VISUAL_REGION_TASKS:
+        # parsed_serializable may be a dict/list or a string; try to find bboxes
+        boxes_and_labels = _find_bboxes_and_labels(parsed_serializable)
+        if boxes_and_labels:
+            try:
+                annotated_image = _draw_bboxes_on_image(image, boxes_and_labels)
+            except Exception as e:
+                # if drawing fails, keep the parsed answer and return no image
+                print("Failed to draw boxes:", e)
+                annotated_image = None
 
+    # Return the parsed answer (a serializable structure) and the annotated PIL image, or None
+    return parsed_serializable, annotated_image
 
+# ---------- UI ----------
 florence_tasks = [
     "<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
     "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
 ]
 
+# Example image
 url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
 example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
 
 with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
-    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the results.")
+    gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the parsed output and an annotated image (when bounding boxes are present).")
 
     with gr.Row():
         with gr.Column(scale=2):
@@ -182,7 +301,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
             task_prompt = gr.Dropdown(
                 label="Select Task",
                 choices=florence_tasks,
-                value="<OD>"
+                value="<MORE_DETAILED_CAPTION>"
             )
             model_choice = gr.Radio(
                 choices=list(MODEL_IDS.keys()),
@@ -202,13 +321,13 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
         with gr.Column(scale=3):
             gr.Markdown("## Output", elem_id="output-title")
             parsed_output = gr.JSON(label="Parsed Answer")
-            output_image = gr.Image(label="Image with Bounding Boxes", type="pil")
+            annotated_output = gr.Image(label="Annotated Image (if available)", type="pil")
 
     image_submit.click(
         fn=run_florence2_inference,
         inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
-        outputs=[output_image, parsed_output]
+        outputs=[parsed_output, annotated_output]
     )
 
 if __name__ == "__main__":
-    demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)
+    demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)
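
Note: a minimal, model-free sketch of what the new helpers consume and produce. It assumes the structured shape the Florence-2 model card documents for region tasks (a dict keyed by the task token, with 'bboxes' and 'labels'); the output path smoke_test.png is purely illustrative.

# Hypothetical smoke test for _find_bboxes_and_labels / _draw_bboxes_on_image;
# no model download required. Assumes post_process_generation returns, for "<OD>":
#   {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": [...]}}
from PIL import Image

fake_parsed = {"<OD>": {"bboxes": [[40, 60, 200, 180]], "labels": ["car"]}}
boxes = _find_bboxes_and_labels(fake_parsed)
assert boxes == [([40, 60, 200, 180], "car")]

canvas = Image.new("RGB", (320, 240), "white")    # blank test image
annotated = _draw_bboxes_on_image(canvas, boxes)  # draws one red box plus its label
annotated.save("smoke_test.png")                  # illustrative output path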