ChaseHan committed on
Commit
1d94259
·
verified ·
1 Parent(s): de9cba4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -28
app.py CHANGED
@@ -14,8 +14,8 @@ MODEL_BASE_ID = "ChaseHan/Latex2Layout-2000-sync"
14
  MODEL_ENHANCED_NAME = "Latex2Layout-Enhanced"
15
  MODEL_ENHANCED_ID = "ChaseHan/Latex2Layout-RL"
16
 
17
- # --- NEW: Add a name for the Mixing mode ---
18
- MODEL_MIXING_NAME = "Mixing (Base + Enhanced)"
19
  MODEL_CHOICES = [MODEL_BASE_NAME, MODEL_ENHANCED_NAME, MODEL_MIXING_NAME]
20
 
21
 
@@ -46,8 +46,6 @@ DEFAULT_PROMPT = (
46
  # --- 2. Load Models and Processor ---
47
  print("Loading models, this will take some time and VRAM...")
48
  try:
49
- # WARNING: Loading two 3B models without quantization requires a large amount of VRAM (>12 GB).
50
- # This may fail on hardware with insufficient memory.
51
  print(f"Loading {MODEL_BASE_NAME}...")
52
  model_base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
53
  MODEL_BASE_ID,
@@ -62,42 +60,84 @@ try:
62
  device_map="auto"
63
  )
64
 
65
- # Processor is the same for both models
66
  processor = AutoProcessor.from_pretrained(MODEL_BASE_ID)
67
  print("All models loaded successfully!")
68
  except Exception as e:
69
  print(f"Error loading models: {e}")
70
  exit()
71
 
72
- # --- NEW: Helper function to calculate Intersection over Union ---
73
  def calculate_iou(boxA, boxB):
74
  """Calculate Intersection over Union (IoU) of two bounding boxes."""
75
- # Determine the coordinates of the intersection rectangle
76
  xA = max(boxA[0], boxB[0])
77
  yA = max(boxA[1], boxB[1])
78
  xB = min(boxA[2], boxB[2])
79
  yB = min(boxA[3], boxB[3])
80
 
81
- # Compute the area of intersection
82
  interArea = max(0, xB - xA) * max(0, yB - yA)
83
-
84
- # Compute the area of both bounding boxes
85
  boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
86
  boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
87
-
88
- # Compute the area of union
89
  unionArea = float(boxAArea + boxBArea - interArea)
90
-
91
- # Return the IoU
92
  return interArea / unionArea if unionArea > 0 else 0
93
 
94
- # --- 3. Core Inference and Visualization Function (MODIFIED) ---
95
- @GPU
96
- def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name: str, prompt: str, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
97
  """
98
- Takes an image and model parameters, runs inference, and returns a visualized image and raw text output.
99
- Supports running a single model or mixing results from two models.
100
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  if input_image is None:
102
  return None, "Please upload an image first."
103
 
@@ -105,7 +145,6 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
105
  image_resized = input_image.resize(TARGET_SIZE)
106
  image_rgba = image_resized.convert("RGBA")
107
 
108
- # --- Nested helper function to run inference on a given model ---
109
  def run_inference(model_to_run, model_name_desc):
110
  progress(0.1, desc=f"Preparing inputs for {model_name_desc}...")
111
  messages = [{"role": "user", "content": [{"type": "image", "image": image_rgba}, {"type": "text", "text": prompt}]}]
@@ -123,20 +162,17 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
123
  json_str = json_match.group(1).strip() if json_match else raw_text.strip()
124
  return json.loads(json_str), raw_text
125
  except (json.JSONDecodeError, AttributeError):
126
- return None, raw_text # Return raw text on failure for debugging
127
 
128
- # --- Main logic: single model or mixing ---
129
  if selected_model_name == MODEL_MIXING_NAME:
130
  base_results, raw_text_base = run_inference(model_base, "Base Model")
131
  enhanced_results, raw_text_enhanced = run_inference(model_enhanced, "Enhanced Model")
132
-
133
  output_text = f"--- Base Model Output ---\n{raw_text_base}\n\n--- Enhanced Model Output ---\n{raw_text_enhanced}"
134
 
135
  if base_results is None or enhanced_results is None:
136
  return image_rgba.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
137
 
138
- # Merge results based on IoU
139
- progress(0.8, desc="Merging results from both models...")
140
  merged_results = list(base_results)
141
  base_bboxes = [item['bbox_2d'] for item in base_results if 'bbox_2d' in item]
142
 
@@ -154,14 +190,18 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
154
 
155
  results = merged_results
156
  else:
157
- # Run a single model
158
  model = model_base if selected_model_name == MODEL_BASE_NAME else model_enhanced
159
  results, output_text = run_inference(model, selected_model_name)
160
  if results is None:
161
  return image_rgba.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
162
 
 
 
 
 
 
163
  # --- Visualization ---
164
- progress(0.9, desc="Parsing and visualizing final results...")
165
  overlay = Image.new('RGBA', image_rgba.size, (255, 255, 255, 0))
166
  draw = ImageDraw.Draw(overlay)
167
 
@@ -219,7 +259,7 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
219
  choices=MODEL_CHOICES,
220
  value=MODEL_BASE_NAME,
221
  label="Select Model",
222
- info="Choose which model to use for inference. 'Mixing' combines the results of both."
223
  )
224
  prompt_textbox = gr.Textbox(
225
  label="Prompt",
 
14
  MODEL_ENHANCED_NAME = "Latex2Layout-Enhanced"
15
  MODEL_ENHANCED_ID = "ChaseHan/Latex2Layout-RL"
16
 
17
+ # Add a name for the Mixing mode
18
+ MODEL_MIXING_NAME = "Mixing Beta Version(Powerful Mode)"
19
  MODEL_CHOICES = [MODEL_BASE_NAME, MODEL_ENHANCED_NAME, MODEL_MIXING_NAME]
20
 
21
 
 
46
  # --- 2. Load Models and Processor ---
47
  print("Loading models, this will take some time and VRAM...")
48
  try:
 
 
49
  print(f"Loading {MODEL_BASE_NAME}...")
50
  model_base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
51
  MODEL_BASE_ID,
 
60
  device_map="auto"
61
  )
62
 
 
63
  processor = AutoProcessor.from_pretrained(MODEL_BASE_ID)
64
  print("All models loaded successfully!")
65
  except Exception as e:
66
  print(f"Error loading models: {e}")
67
  exit()
68
 
69
+ # --- Helper functions for geometric calculations ---
70
def calculate_iou(boxA, boxB):
    """Return the Intersection over Union (IoU) of two bounding boxes.

    Both boxes are indexed as ``[x_min, y_min, x_max, y_max]``.

    Args:
        boxA: First box, a sequence of four numbers.
        boxB: Second box, a sequence of four numbers.

    Returns:
        The IoU as a float. ``0.0`` is returned when the union area is not
        positive (e.g. both boxes are degenerate), so the result type is
        always ``float``.
    """
    # Width/height of the overlap rectangle, clamped to zero when disjoint.
    inter_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    inter_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    inter_area = inter_w * inter_h

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union_area = float(area_a + area_b - inter_area)

    # Guard against division by zero (and non-positive unions from
    # degenerate boxes); keep the return type a float in that case too.
    if union_area <= 0:
        return 0.0
    return inter_area / union_area
82
 
83
def calculate_intersection_area(boxA, boxB):
    """Return the absolute area shared by two bounding boxes.

    Both boxes are indexed as ``[x_min, y_min, x_max, y_max]``. The result is
    ``0`` when the boxes do not overlap (or only touch along an edge).
    """
    # Corners of the overlap rectangle: the innermost of each pair of edges.
    x_lo, y_lo = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    x_hi, y_hi = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    # A negative extent means no overlap along that axis; clamp it to zero.
    overlap_w = max(0, x_hi - x_lo)
    overlap_h = max(0, y_hi - y_lo)
    return overlap_w * overlap_h
90
+
91
+ # --- NEW: Function to remove nested elements of the same type ---
92
def remove_nested_elements(results, threshold=0.8):
    """Drop elements that are mostly nested inside a larger same-label element.

    For every pair of entries sharing the same ``"label"``, the smaller box
    is removed when more than ``threshold`` of its own area lies inside the
    other box. When the two areas are equal (e.g. exact duplicates), only the
    later entry is removed, so one copy always survives.

    Args:
        results: List of dicts; the keys read are ``"label"`` and
            ``"bbox_2d"`` (indexed as ``[x_min, y_min, x_max, y_max]``).
        threshold: Fraction of the smaller box's area that must be covered
            for it to count as nested. Defaults to ``0.8``.

    Returns:
        A new list with the nested entries removed, in the original order.
        Entries without a usable box (missing, malformed, or with
        non-positive area) are never removed and never cause a removal.
    """
    from itertools import combinations

    # Pre-compute (index, label, bbox, area) once per entry instead of once
    # per pair; entries without a usable box simply do not take part.
    boxed = []
    for idx, item in enumerate(results):
        bbox = item.get("bbox_2d")
        if not bbox:
            continue
        try:
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        except (TypeError, IndexError):
            # Malformed box from the model output: leave the entry alone.
            continue
        if area <= 0:
            continue
        boxed.append((idx, item.get("label"), bbox, area))

    indices_to_remove = set()
    # Visit each unordered pair exactly once. Visiting both (i, j) and
    # (j, i) would mark *both* boxes of an equal-area pair for removal.
    for (idx_a, label_a, box_a, area_a), (idx_b, label_b, box_b, area_b) in combinations(boxed, 2):
        # The rule only applies to elements with the same label.
        if label_a != label_b:
            continue

        # On a tie the later entry is treated as the smaller one.
        if area_a < area_b:
            small_idx, small_box, small_area, big_box = idx_a, box_a, area_a, box_b
        else:
            small_idx, small_box, small_area, big_box = idx_b, box_b, area_b, box_a

        # Overlap area, computed inline so this function is self-contained.
        overlap_w = max(0, min(small_box[2], big_box[2]) - max(small_box[0], big_box[0]))
        overlap_h = max(0, min(small_box[3], big_box[3]) - max(small_box[1], big_box[1]))

        if (overlap_w * overlap_h) / small_area > threshold:
            indices_to_remove.add(small_idx)

    return [item for idx, item in enumerate(results) if idx not in indices_to_remove]
136
+
137
+
138
+ # --- 3. Core Inference and Visualization Function ---
139
+ @GPU
140
+ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name: str, prompt: str, progress=gr.Progress(track_tqdm=True)):
141
  if input_image is None:
142
  return None, "Please upload an image first."
143
 
 
145
  image_resized = input_image.resize(TARGET_SIZE)
146
  image_rgba = image_resized.convert("RGBA")
147
 
 
148
  def run_inference(model_to_run, model_name_desc):
149
  progress(0.1, desc=f"Preparing inputs for {model_name_desc}...")
150
  messages = [{"role": "user", "content": [{"type": "image", "image": image_rgba}, {"type": "text", "text": prompt}]}]
 
162
  json_str = json_match.group(1).strip() if json_match else raw_text.strip()
163
  return json.loads(json_str), raw_text
164
  except (json.JSONDecodeError, AttributeError):
165
+ return None, raw_text
166
 
 
167
  if selected_model_name == MODEL_MIXING_NAME:
168
  base_results, raw_text_base = run_inference(model_base, "Base Model")
169
  enhanced_results, raw_text_enhanced = run_inference(model_enhanced, "Enhanced Model")
 
170
  output_text = f"--- Base Model Output ---\n{raw_text_base}\n\n--- Enhanced Model Output ---\n{raw_text_enhanced}"
171
 
172
  if base_results is None or enhanced_results is None:
173
  return image_rgba.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
174
 
175
+ progress(0.8, desc="Merging results based on IoU...")
 
176
  merged_results = list(base_results)
177
  base_bboxes = [item['bbox_2d'] for item in base_results if 'bbox_2d' in item]
178
 
 
190
 
191
  results = merged_results
192
  else:
 
193
  model = model_base if selected_model_name == MODEL_BASE_NAME else model_enhanced
194
  results, output_text = run_inference(model, selected_model_name)
195
  if results is None:
196
  return image_rgba.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
197
 
198
+ # --- NEW: Apply the final post-processing step to remove nested elements ---
199
+ progress(0.85, desc="Cleaning up nested elements...")
200
+ results = remove_nested_elements(results)
201
+
202
+
203
  # --- Visualization ---
204
+ progress(0.9, desc="Visualizing final results...")
205
  overlay = Image.new('RGBA', image_rgba.size, (255, 255, 255, 0))
206
  draw = ImageDraw.Draw(overlay)
207
 
 
259
  choices=MODEL_CHOICES,
260
  value=MODEL_BASE_NAME,
261
  label="Select Model",
262
+ info="Choose which model to use for inference. "
263
  )
264
  prompt_textbox = gr.Textbox(
265
  label="Prompt",