Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,8 +14,8 @@ MODEL_BASE_ID = "ChaseHan/Latex2Layout-2000-sync"
|
|
| 14 |
MODEL_ENHANCED_NAME = "Latex2Layout-Enhanced"
|
| 15 |
MODEL_ENHANCED_ID = "ChaseHan/Latex2Layout-RL"
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
MODEL_MIXING_NAME = "Mixing (
|
| 19 |
MODEL_CHOICES = [MODEL_BASE_NAME, MODEL_ENHANCED_NAME, MODEL_MIXING_NAME]
|
| 20 |
|
| 21 |
|
|
@@ -46,8 +46,6 @@ DEFAULT_PROMPT = (
|
|
| 46 |
# --- 2. Load Models and Processor ---
|
| 47 |
print("Loading models, this will take some time and VRAM...")
|
| 48 |
try:
|
| 49 |
-
# WARNING: Loading two 3B models without quantization requires a large amount of VRAM (>12 GB).
|
| 50 |
-
# This may fail on hardware with insufficient memory.
|
| 51 |
print(f"Loading {MODEL_BASE_NAME}...")
|
| 52 |
model_base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 53 |
MODEL_BASE_ID,
|
|
@@ -62,42 +60,84 @@ try:
|
|
| 62 |
device_map="auto"
|
| 63 |
)
|
| 64 |
|
| 65 |
-
# Processor is the same for both models
|
| 66 |
processor = AutoProcessor.from_pretrained(MODEL_BASE_ID)
|
| 67 |
print("All models loaded successfully!")
|
| 68 |
except Exception as e:
|
| 69 |
print(f"Error loading models: {e}")
|
| 70 |
exit()
|
| 71 |
|
| 72 |
-
# ---
|
| 73 |
def calculate_iou(boxA, boxB):
|
| 74 |
"""Calculate Intersection over Union (IoU) of two bounding boxes."""
|
| 75 |
-
# Determine the coordinates of the intersection rectangle
|
| 76 |
xA = max(boxA[0], boxB[0])
|
| 77 |
yA = max(boxA[1], boxB[1])
|
| 78 |
xB = min(boxA[2], boxB[2])
|
| 79 |
yB = min(boxA[3], boxB[3])
|
| 80 |
|
| 81 |
-
# Compute the area of intersection
|
| 82 |
interArea = max(0, xB - xA) * max(0, yB - yA)
|
| 83 |
-
|
| 84 |
-
# Compute the area of both bounding boxes
|
| 85 |
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
| 86 |
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
| 87 |
-
|
| 88 |
-
# Compute the area of union
|
| 89 |
unionArea = float(boxAArea + boxBArea - interArea)
|
| 90 |
-
|
| 91 |
-
# Return the IoU
|
| 92 |
return interArea / unionArea if unionArea > 0 else 0
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"""
|
| 98 |
-
|
| 99 |
-
|
| 100 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
if input_image is None:
|
| 102 |
return None, "Please upload an image first."
|
| 103 |
|
|
@@ -105,7 +145,6 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
|
|
| 105 |
image_resized = input_image.resize(TARGET_SIZE)
|
| 106 |
image_rgba = image_resized.convert("RGBA")
|
| 107 |
|
| 108 |
-
# --- Nested helper function to run inference on a given model ---
|
| 109 |
def run_inference(model_to_run, model_name_desc):
|
| 110 |
progress(0.1, desc=f"Preparing inputs for {model_name_desc}...")
|
| 111 |
messages = [{"role": "user", "content": [{"type": "image", "image": image_rgba}, {"type": "text", "text": prompt}]}]
|
|
@@ -123,20 +162,17 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
|
|
| 123 |
json_str = json_match.group(1).strip() if json_match else raw_text.strip()
|
| 124 |
return json.loads(json_str), raw_text
|
| 125 |
except (json.JSONDecodeError, AttributeError):
|
| 126 |
-
return None, raw_text
|
| 127 |
|
| 128 |
-
# --- Main logic: single model or mixing ---
|
| 129 |
if selected_model_name == MODEL_MIXING_NAME:
|
| 130 |
base_results, raw_text_base = run_inference(model_base, "Base Model")
|
| 131 |
enhanced_results, raw_text_enhanced = run_inference(model_enhanced, "Enhanced Model")
|
| 132 |
-
|
| 133 |
output_text = f"--- Base Model Output ---\n{raw_text_base}\n\n--- Enhanced Model Output ---\n{raw_text_enhanced}"
|
| 134 |
|
| 135 |
if base_results is None or enhanced_results is None:
|
| 136 |
return image_rgba.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
|
| 137 |
|
| 138 |
-
|
| 139 |
-
progress(0.8, desc="Merging results from both models...")
|
| 140 |
merged_results = list(base_results)
|
| 141 |
base_bboxes = [item['bbox_2d'] for item in base_results if 'bbox_2d' in item]
|
| 142 |
|
|
@@ -154,14 +190,18 @@ def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name:
|
|
| 154 |
|
| 155 |
results = merged_results
|
| 156 |
else:
|
| 157 |
-
# Run a single model
|
| 158 |
model = model_base if selected_model_name == MODEL_BASE_NAME else model_enhanced
|
| 159 |
results, output_text = run_inference(model, selected_model_name)
|
| 160 |
if results is None:
|
| 161 |
return image_rgba.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
# --- Visualization ---
|
| 164 |
-
progress(0.9, desc="
|
| 165 |
overlay = Image.new('RGBA', image_rgba.size, (255, 255, 255, 0))
|
| 166 |
draw = ImageDraw.Draw(overlay)
|
| 167 |
|
|
@@ -219,7 +259,7 @@ with gr.Blocks(theme=gr.themes.Glass(), title="Academic Paper Layout Detection")
|
|
| 219 |
choices=MODEL_CHOICES,
|
| 220 |
value=MODEL_BASE_NAME,
|
| 221 |
label="Select Model",
|
| 222 |
-
info="Choose which model to use for inference.
|
| 223 |
)
|
| 224 |
prompt_textbox = gr.Textbox(
|
| 225 |
label="Prompt",
|
|
|
|
| 14 |
MODEL_ENHANCED_NAME = "Latex2Layout-Enhanced"
|
| 15 |
MODEL_ENHANCED_ID = "ChaseHan/Latex2Layout-RL"
|
| 16 |
|
| 17 |
+
# Add a name for the Mixing mode
|
| 18 |
+
MODEL_MIXING_NAME = "Mixing Beta Version(Powerful Mode)"
|
| 19 |
MODEL_CHOICES = [MODEL_BASE_NAME, MODEL_ENHANCED_NAME, MODEL_MIXING_NAME]
|
| 20 |
|
| 21 |
|
|
|
|
| 46 |
# --- 2. Load Models and Processor ---
|
| 47 |
print("Loading models, this will take some time and VRAM...")
|
| 48 |
try:
|
|
|
|
|
|
|
| 49 |
print(f"Loading {MODEL_BASE_NAME}...")
|
| 50 |
model_base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 51 |
MODEL_BASE_ID,
|
|
|
|
| 60 |
device_map="auto"
|
| 61 |
)
|
| 62 |
|
|
|
|
| 63 |
processor = AutoProcessor.from_pretrained(MODEL_BASE_ID)
|
| 64 |
print("All models loaded successfully!")
|
| 65 |
except Exception as e:
|
| 66 |
print(f"Error loading models: {e}")
|
| 67 |
exit()
|
| 68 |
|
| 69 |
+
# --- Helper functions for geometric calculations ---
|
| 70 |
def calculate_iou(boxA, boxB):
    """Calculate Intersection over Union (IoU) of two bounding boxes.

    Boxes are (x1, y1, x2, y2) corner coordinates. Returns 0 when the
    union area is empty (e.g. both boxes are degenerate).
    """
    # Corners of the intersection rectangle.
    ix1, iy1 = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    ix2, iy2 = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    # Clamp to zero so disjoint boxes contribute no overlap.
    overlap = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = float(area_a + area_b - overlap)

    if union > 0:
        return overlap / union
    return 0
|
| 82 |
|
| 83 |
+
def calculate_intersection_area(boxA, boxB):
    """Calculate the absolute intersection area of two bounding boxes.

    Boxes are (x1, y1, x2, y2); returns 0 for disjoint boxes.
    """
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    return max(0, xB - xA) * max(0, yB - yA)


# --- Post-processing: remove nested elements of the same type ---
def remove_nested_elements(results):
    """Remove smaller elements heavily nested inside larger same-label elements.

    An element is considered nested when more than 80% of its area lies
    inside another element carrying the same label. Items without a usable
    "bbox_2d", or with zero/negative area, are left untouched.

    Args:
        results: list of dicts with optional "label" and "bbox_2d" keys,
            where bbox_2d is [x1, y1, x2, y2].

    Returns:
        A new list containing only the elements not marked for removal.
    """
    indices_to_remove = set()
    # Visit each unordered pair exactly once. The previous version iterated
    # over ordered pairs, processing every pair twice; for two same-label
    # boxes of EQUAL area with >80% mutual overlap (e.g. exact duplicates)
    # that removed BOTH boxes, so duplicates vanished entirely instead of
    # being deduplicated down to one.
    for i in range(len(results)):
        for j in range(i + 1, len(results)):
            item_i = results[i]
            item_j = results[j]

            # Rule only applies to elements with the same label.
            if item_i.get("label") != item_j.get("label"):
                continue

            bbox_i = item_i.get("bbox_2d")
            bbox_j = item_j.get("bbox_2d")
            if not bbox_i or not bbox_j:
                continue

            area_i = (bbox_i[2] - bbox_i[0]) * (bbox_i[3] - bbox_i[1])
            area_j = (bbox_j[2] - bbox_j[0]) * (bbox_j[3] - bbox_j[1])
            # <= 0 (not == 0) also skips malformed boxes with inverted
            # coordinates, whose negative area would flip the ratio test.
            if area_i <= 0 or area_j <= 0:
                continue

            # The smaller box is the removal candidate; on an area tie keep
            # the earlier element so duplicates dedupe to one survivor.
            if area_i < area_j:
                smaller_box, larger_box, smaller_area, smaller_idx = bbox_i, bbox_j, area_i, i
            else:
                smaller_box, larger_box, smaller_area, smaller_idx = bbox_j, bbox_i, area_j, j

            intersection = calculate_intersection_area(smaller_box, larger_box)

            # If the smaller box is >80% contained in the larger one, drop it.
            if (intersection / smaller_area) > 0.8:
                indices_to_remove.add(smaller_idx)

    # New list with only the elements that were not marked for removal.
    return [item for idx, item in enumerate(results) if idx not in indices_to_remove]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# --- 3. Core Inference and Visualization Function ---
|
| 139 |
+
@GPU
|
| 140 |
+
def analyze_and_visualize_layout(input_image: Image.Image, selected_model_name: str, prompt: str, progress=gr.Progress(track_tqdm=True)):
|
| 141 |
if input_image is None:
|
| 142 |
return None, "Please upload an image first."
|
| 143 |
|
|
|
|
| 145 |
image_resized = input_image.resize(TARGET_SIZE)
|
| 146 |
image_rgba = image_resized.convert("RGBA")
|
| 147 |
|
|
|
|
| 148 |
def run_inference(model_to_run, model_name_desc):
|
| 149 |
progress(0.1, desc=f"Preparing inputs for {model_name_desc}...")
|
| 150 |
messages = [{"role": "user", "content": [{"type": "image", "image": image_rgba}, {"type": "text", "text": prompt}]}]
|
|
|
|
| 162 |
json_str = json_match.group(1).strip() if json_match else raw_text.strip()
|
| 163 |
return json.loads(json_str), raw_text
|
| 164 |
except (json.JSONDecodeError, AttributeError):
|
| 165 |
+
return None, raw_text
|
| 166 |
|
|
|
|
| 167 |
if selected_model_name == MODEL_MIXING_NAME:
|
| 168 |
base_results, raw_text_base = run_inference(model_base, "Base Model")
|
| 169 |
enhanced_results, raw_text_enhanced = run_inference(model_enhanced, "Enhanced Model")
|
|
|
|
| 170 |
output_text = f"--- Base Model Output ---\n{raw_text_base}\n\n--- Enhanced Model Output ---\n{raw_text_enhanced}"
|
| 171 |
|
| 172 |
if base_results is None or enhanced_results is None:
|
| 173 |
return image_rgba.convert("RGB"), f"Failed to parse JSON from one or both models:\n\n{output_text}"
|
| 174 |
|
| 175 |
+
progress(0.8, desc="Merging results based on IoU...")
|
|
|
|
| 176 |
merged_results = list(base_results)
|
| 177 |
base_bboxes = [item['bbox_2d'] for item in base_results if 'bbox_2d' in item]
|
| 178 |
|
|
|
|
| 190 |
|
| 191 |
results = merged_results
|
| 192 |
else:
|
|
|
|
| 193 |
model = model_base if selected_model_name == MODEL_BASE_NAME else model_enhanced
|
| 194 |
results, output_text = run_inference(model, selected_model_name)
|
| 195 |
if results is None:
|
| 196 |
return image_rgba.convert("RGB"), f"Failed to parse JSON from model output:\n\n{output_text}"
|
| 197 |
|
| 198 |
+
# --- NEW: Apply the final post-processing step to remove nested elements ---
|
| 199 |
+
progress(0.85, desc="Cleaning up nested elements...")
|
| 200 |
+
results = remove_nested_elements(results)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
# --- Visualization ---
|
| 204 |
+
progress(0.9, desc="Visualizing final results...")
|
| 205 |
overlay = Image.new('RGBA', image_rgba.size, (255, 255, 255, 0))
|
| 206 |
draw = ImageDraw.Draw(overlay)
|
| 207 |
|
|
|
|
| 259 |
choices=MODEL_CHOICES,
|
| 260 |
value=MODEL_BASE_NAME,
|
| 261 |
label="Select Model",
|
| 262 |
+
info="Choose which model to use for inference. "
|
| 263 |
)
|
| 264 |
prompt_textbox = gr.Textbox(
|
| 265 |
label="Prompt",
|