ChaseHan committed
Commit de9cba4 · verified · 1 Parent(s): 56e25bb

Update app.py
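This commit replaces the enhanced model's load-and-merge LoRA pipeline (ZelongWang base model + GRPO adapter applied via PEFT at startup) with a single pre-merged checkpoint, ChaseHan/Latex2Layout-RL, drops the greedy/temperature/top-p sampling controls in favor of always-greedy decoding, and tidies the Gradio UI (Base model as the default selection, expanded Advanced Settings, revised welcome text).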

Files changed (1): app.py (+63, −89)
app.py CHANGED
@@ -5,30 +5,28 @@ from PIL import Image, ImageDraw, ImageFont
 import json
 import re
 from spaces import GPU
-from peft import PeftModel
 
 # --- 1. Configurations and Constants ---
 # Define user-facing names and Hugging Face IDs for the models
 MODEL_BASE_NAME = "Latex2Layout-Base"
 MODEL_BASE_ID = "ChaseHan/Latex2Layout-2000-sync"
 
-MODEL_ENHANCED_NAME = "Qwen2.5-VL + GRPO LoRA (Merged)"
-MODEL_ENHANCED_BASE_ID = "ZelongWang/Qwen2.5-VL-3B-Instruct-DocOD-2"
-MODEL_ENHANCED_LORA_ID = "ZelongWang/Qwen2.5-VL-3B-GRPO-lora-pdf-v3"
-LORA_CHECKPOINT_FOLDER = "checkpoint-525"  # Subfolder containing the adapter
+MODEL_ENHANCED_NAME = "Latex2Layout-Enhanced"
+MODEL_ENHANCED_ID = "ChaseHan/Latex2Layout-RL"
 
 # --- NEW: Add a name for the Mixing mode ---
 MODEL_MIXING_NAME = "Mixing (Base + Enhanced)"
-
 MODEL_CHOICES = [MODEL_BASE_NAME, MODEL_ENHANCED_NAME, MODEL_MIXING_NAME]
 
+
 # Target image size for model input
 TARGET_SIZE = (924, 1204)
 
 # Visualization Style Constants
 OUTLINE_WIDTH = 3
+# Color mapping for different layout regions (RGBA for transparency)
 LABEL_COLORS = {
-    "title": (255, 82, 82, 90),     # Red
+    "title": (255, 82, 82, 90),     # Red
     "abstract": (46, 204, 113, 90), # Green
     "heading": (52, 152, 219, 90),  # Blue
     "footnote": (241, 196, 15, 90), # Yellow
@@ -48,46 +46,30 @@ DEFAULT_PROMPT = (
 # --- 2. Load Models and Processor ---
 print("Loading models, this will take some time and VRAM...")
 try:
-    # Load the original base model
-    print(f"Loading base model: {MODEL_BASE_NAME}...")
+    # WARNING: Loading two 3B models without quantization requires a large amount of VRAM (>12 GB).
+    # This may fail on hardware with insufficient memory.
+    print(f"Loading {MODEL_BASE_NAME}...")
     model_base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         MODEL_BASE_ID,
         torch_dtype=torch.float16,
         device_map="auto"
     )
 
-    # Load and merge the new enhanced model directly from the Hub
-    print(f"Loading enhanced model base: {MODEL_ENHANCED_BASE_ID}...")
-    # Step 1: Load the new base model
+    print(f"Loading {MODEL_ENHANCED_NAME}...")
     model_enhanced = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        MODEL_ENHANCED_BASE_ID,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-    )
-
-    print(f"Loading LoRA adapter online from: {MODEL_ENHANCED_LORA_ID}...")
-    # Step 2: Load Peft adapter directly from the Hub, specifying the subfolder
-    model_enhanced = PeftModel.from_pretrained(
-        model_enhanced,
-        MODEL_ENHANCED_LORA_ID,
-        subfolder=LORA_CHECKPOINT_FOLDER,
+        MODEL_ENHANCED_ID,
+        torch_dtype=torch.float16,
         device_map="auto"
     )
-
-    # Step 3: Merge the adapter weights and unload the PeftModel
-    print("Merging LoRA adapter...")
-    model_enhanced = model_enhanced.merge_and_unload()
-    print(f"Successfully loaded and merged model: {MODEL_ENHANCED_NAME}")
 
-    # Load processor
+    # Processor is the same for both models
     processor = AutoProcessor.from_pretrained(MODEL_BASE_ID)
-    print("All models and processor loaded successfully!")
+    print("All models loaded successfully!")
 except Exception as e:
     print(f"Error loading models: {e}")
     exit()
 
-# --- 3. Core Inference, Merging, and Visualization ---
-
+# --- NEW: Helper function to calculate Intersection over Union ---
 def calculate_iou(boxA, boxB):
     """Calculate Intersection over Union (IoU) of two bounding boxes."""
     # Determine the coordinates of the intersection rectangle
109
  # Return the IoU
110
  return interArea / unionArea if unionArea > 0 else 0
111
 
 
112
  @GPU
113
- def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name: str, prompt: str, use_greedy: bool, temperature: float, top_p: float, progress=gr.Progress(track_tqdm=True)):
114
  """
115
  Takes an image and model parameters, runs inference, and returns a visualized image and raw text output.
116
  Supports running a single model or mixing results from two models.
@@ -119,50 +102,40 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
         return None, "Please upload an image first."
 
     progress(0, desc="Resizing image...")
-    image = input_image.resize(TARGET_SIZE).convert("RGBA")
+    image_resized = input_image.resize(TARGET_SIZE)
+    image_rgba = image_resized.convert("RGBA")
 
-    # --- Nested function to run inference on a given model ---
+    # --- Nested helper function to run inference on a given model ---
     def run_inference(model_to_run, model_name_desc):
         progress(0.1, desc=f"Preparing inputs for {model_name_desc}...")
-        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+        messages = [{"role": "user", "content": [{"type": "image", "image": image_rgba}, {"type": "text", "text": prompt}]}]
         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model_to_run.device)
-
-        gen_kwargs = {"max_new_tokens": 4096}
-        if use_greedy:
-            gen_kwargs["do_sample"] = False
-        else:
-            gen_kwargs["do_sample"] = True
-            gen_kwargs["temperature"] = temperature
-            gen_kwargs["top_p"] = top_p
-
+        inputs = processor(text=[text], images=[image_rgba], padding=True, return_tensors="pt").to(model_to_run.device)
+
         progress(0.5, desc=f"Generating layout data with {model_name_desc}...")
         with torch.no_grad():
-            output_ids = model_to_run.generate(**inputs, **gen_kwargs)
+            output_ids = model_to_run.generate(**inputs, max_new_tokens=4096, do_sample=False)
 
         raw_text = processor.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
 
         try:
            json_match = re.search(r"```json(.*?)```", raw_text, re.DOTALL)
            json_str = json_match.group(1).strip() if json_match else raw_text.strip()
-            parsed_results = json.loads(json_str)
-            return parsed_results, raw_text
+            return json.loads(json_str), raw_text
         except (json.JSONDecodeError, AttributeError):
-            # Return raw text on failure for debugging
-            return None, raw_text
+            return None, raw_text  # Return raw text on failure for debugging
 
     # --- Main logic: single model or mixing ---
     if selected_model_name == MODEL_MIXING_NAME:
-        # Run both models
         base_results, raw_text_base = run_inference(model_base, "Base Model")
         enhanced_results, raw_text_enhanced = run_inference(model_enhanced, "Enhanced Model")
 
         output_text = f"--- Base Model Output ---\n{raw_text_base}\n\n--- Enhanced Model Output ---\n{raw_text_enhanced}"
 
         if base_results is None or enhanced_results is None:
-            return image.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
+            return image_rgba.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
 
-        # Merge results
+        # Merge results based on IoU
         progress(0.8, desc="Merging results from both models...")
        merged_results = list(base_results)
        base_bboxes = [item['bbox_2d'] for item in base_results if 'bbox_2d' in item]
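The fenced-JSON parsing kept in `run_inference` is easy to exercise in isolation. A minimal sketch of the same extract-then-fallback logic (the sample model reply is made up for illustration):

````python
import json
import re

def extract_json(raw_text: str):
    """Pull a ```json ...``` fenced block out of model output, falling back to the whole string."""
    match = re.search(r"```json(.*?)```", raw_text, re.DOTALL)
    json_str = match.group(1).strip() if match else raw_text.strip()
    try:
        return json.loads(json_str)
    except (json.JSONDecodeError, AttributeError):
        return None  # caller keeps raw_text around for debugging

reply = 'Here is the layout:\n```json\n[{"label": "title", "bbox_2d": [40, 30, 880, 90], "order": 1}]\n```'
print(extract_json(reply))  # [{'label': 'title', 'bbox_2d': [40, 30, 880, 90], 'order': 1}]
````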
@@ -172,8 +145,7 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
 
             is_duplicate = False
             for base_bbox in base_bboxes:
-                iou = calculate_iou(enhanced_item['bbox_2d'], base_bbox)
-                if iou > 0.5:  # IoU threshold for duplication
+                if calculate_iou(enhanced_item['bbox_2d'], base_bbox) > 0.5:
                     is_duplicate = True
                     break
 
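The merge rule above keeps an enhanced-model detection only when it overlaps every base-model box by 0.5 IoU or less. Since the diff elides most of `calculate_iou`'s body, here is a self-contained sketch using the standard IoU formulation together with the dedup rule on toy boxes:

```python
def calculate_iou(boxA, boxB):
    """IoU of two [x1, y1, x2, y2] boxes (standard formulation; the app's own body is elided in the diff)."""
    xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0, xB - xA) * max(0, yB - yA)
    areaA = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    areaB = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = areaA + areaB - inter
    return inter / union if union > 0 else 0

base = [[0, 0, 100, 100]]
enhanced = [[5, 5, 105, 105], [200, 200, 300, 300]]
merged = list(base)
for box in enhanced:
    if all(calculate_iou(box, b) <= 0.5 for b in base):  # not a duplicate of any base box
        merged.append(box)
print(merged)  # the near-identical box (IoU ≈ 0.82) is dropped; the distant one is kept
```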
@@ -181,17 +153,16 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
                 merged_results.append(enhanced_item)
 
         results = merged_results
-
     else:
         # Run a single model
         model = model_base if selected_model_name == MODEL_BASE_NAME else model_enhanced
         results, output_text = run_inference(model, selected_model_name)
         if results is None:
-            return image.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
+            return image_rgba.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
 
     # --- Visualization ---
-    progress(0.9, desc="Visualizing final results...")
-    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
+    progress(0.9, desc="Parsing and visualizing final results...")
+    overlay = Image.new('RGBA', image_rgba.size, (255, 255, 255, 0))
     draw = ImageDraw.Draw(overlay)
 
     try:
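The visualization path explains why the image stays in RGBA until the very end: translucent fills must be drawn on a fully transparent overlay and alpha-composited onto the page before the final RGB conversion. A minimal sketch of the same PIL pattern (box coordinates and colors are illustrative):

```python
from PIL import Image, ImageDraw

page = Image.new("RGBA", (924, 1204), "white")               # stands in for the resized input page
overlay = Image.new("RGBA", page.size, (255, 255, 255, 0))   # fully transparent drawing layer
draw = ImageDraw.Draw(overlay)

# Translucent fill (alpha 90) plus an opaque outline, mirroring LABEL_COLORS / OUTLINE_WIDTH
draw.rectangle([40, 30, 880, 90], fill=(255, 82, 82, 90), outline=(255, 82, 82, 255), width=3)

# alpha_composite requires two RGBA images; convert to RGB only at the end
result = Image.alpha_composite(page, overlay).convert("RGB")
```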
@@ -200,9 +171,7 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
         font = ImageFont.load_default()
 
     for item in sorted(results, key=lambda x: x.get("order", 999)):
-        bbox = item.get("bbox_2d")
-        label = item.get("label", "other")
-        order = item.get("order", "")
+        bbox, label, order = item.get("bbox_2d"), item.get("label", "other"), item.get("order", "")
         if not bbox or len(bbox) != 4: continue
 
         fill_color_rgba = LABEL_COLORS.get(label, LABEL_COLORS["other"])
@@ -216,20 +185,23 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
         draw.rectangle(tag_bg_box, fill=solid_color_rgb)
         draw.text((bbox[0] + 5, bbox[1] + 3), tag_text, font=font, fill="white")
 
-    return Image.alpha_composite(image, overlay).convert("RGB"), output_text
+    visualized_image = Image.alpha_composite(image_rgba, overlay).convert("RGB")
+    return visualized_image, output_text
+
 
 def clear_outputs():
+    """Helper function to clear the output fields."""
     return None, None
 
-def toggle_sampling_params(use_greedy):
-    """Updates visibility of temperature and top-p sliders."""
-    is_visible = not use_greedy
-    return gr.update(visible=is_visible)
-
 # --- 4. Gradio User Interface ---
 with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection") as demo:
+
     gr.Markdown("# 📄 Academic Paper Layout Detection")
-    gr.Markdown("Welcome! This tool uses Qwen2.5-VL models to detect layout components in academic papers. You can choose the **Latex2Layout** model, an **Enhanced** version, or **Mix** the results of both.")
+    gr.Markdown(
+        "Welcome! This tool uses a Qwen2.5-VL-3B-Instruct model fine-tuned on our Latex2Layout annotated layout dataset to identify layout regions in academic papers. "
+        "Upload a document image to begin."
+        "\n> **Please note:** All uploaded images are automatically resized to 924x1204 pixels to meet the model's input requirements."
+    )
     gr.Markdown("<hr>")
 
     with gr.Row():
@@ -241,38 +213,40 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
     with gr.Row():
         analyze_btn = gr.Button("✨ Analyze Layout", variant="primary", scale=1)
 
+    # --- Advanced Settings Panel ---
     with gr.Accordion("Advanced Settings", open=False):
         model_selector = gr.Radio(
-            choices=MODEL_CHOICES,
-            value=MODEL_MIXING_NAME,  # Default to the new mixing mode
-            label="Select Model"
+            choices=MODEL_CHOICES,
+            value=MODEL_BASE_NAME,
+            label="Select Model",
+            info="Choose which model to use for inference. 'Mixing' combines the results of both."
+        )
+        prompt_textbox = gr.Textbox(
+            label="Prompt",
+            value=DEFAULT_PROMPT,
+            lines=5,
+            info="The prompt used to instruct the model."
         )
-        prompt_textbox = gr.Textbox(label="Prompt", value=DEFAULT_PROMPT, lines=5)
-
-        greedy_checkbox = gr.Checkbox(label="Use Greedy Decoding", value=True, info="Faster and deterministic. Uncheck to enable Temperature and Top-p.")
-
-        temp_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0.7, label="Temperature")
-        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Top-p")
 
-    output_text = gr.Textbox(label="Model Raw Output", lines=10, interactive=False, visible=True)
-    gr.Examples(examples=[["1.png"], ["2.png"], ["12.png"], ["13.png"], ["14.png"], ["11.png"], ["3.png"], ["7.png"], ["8.png"]], inputs=[input_image], label="Examples (Click to Run)")
-    gr.Markdown("<p style='text-align:center; color:grey;'>Powered by the Latex2Layout dataset by Feijiang Han</p>")
+    output_text = gr.Textbox(label="Model Raw Output", lines=8, interactive=False, visible=True)
+
+    gr.Examples(
+        examples=[["1.png"], ["2.png"], ["12.png"], ["13.png"], ["14.png"], ["11.png"], ["3.png"], ["7.png"], ["8.png"]],
+        inputs=[input_image],
+        label="Examples (Click to Run)",
+    )
+
+    gr.Markdown("<p style='text-align:center; color:grey;'>Powered by the Latex2Layout dataset generated by Feijiang Han</p>")
 
     # --- Event Handlers ---
     analyze_btn.click(
         fn=analyze_and_visualize_layout,
-        inputs=[input_image, model_selector, prompt_textbox, greedy_checkbox, temp_slider, top_p_slider],
+        inputs=[input_image, model_selector, prompt_textbox],
         outputs=[output_image, output_text]
     )
 
     input_image.upload(fn=clear_outputs, inputs=None, outputs=[output_image, output_text])
     input_image.clear(fn=clear_outputs, inputs=None, outputs=[output_image, output_text])
-
-    greedy_checkbox.change(
-        fn=toggle_sampling_params,
-        inputs=greedy_checkbox,
-        outputs=[sampling_params]
-    )
 
 # --- 5. Launch the Application ---
 if __name__ == "__main__":
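One more detail in the removed code: the `greedy_checkbox.change` handler targeted `outputs=[sampling_params]`, a name that is never defined in any of the removed UI lines shown in this diff, and `toggle_sampling_params` returned a single update for what were two separate sliders. Had the sampling controls stayed, a working version of that toggle would look roughly like this (a sketch; it updates the two sliders the old code actually defined):

```python
def toggle_sampling_params(use_greedy):
    """Show the temperature and top-p sliders only when sampling is enabled."""
    is_visible = not use_greedy
    # One gr.update per output component
    return gr.update(visible=is_visible), gr.update(visible=is_visible)

greedy_checkbox.change(
    fn=toggle_sampling_params,
    inputs=greedy_checkbox,
    outputs=[temp_slider, top_p_slider],  # the sliders defined in the old Accordion
)
```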
 