Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,24 +25,31 @@ MODEL_REGISTRY = {
|
|
| 25 |
"grounding_dino": "grounding_dino_detect_model"
|
| 26 |
}
|
| 27 |
}
|
| 28 |
-
LOADED_MODELS = {}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
def get_model(task, architecture):
|
| 31 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 32 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
model_name = MODEL_REGISTRY[task][architecture]
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
return LOADED_MODELS[model_name]
|
| 39 |
except KeyError as e:
|
| 40 |
raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}")
|
| 41 |
except Exception as e:
|
| 42 |
raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}")
|
| 43 |
|
| 44 |
# --- Visualization and Drawing Functions ---
|
| 45 |
-
# Note: The global font object has been removed from here.
|
| 46 |
|
| 47 |
def draw_yolo_predictions(image, results, font, color="red"):
|
| 48 |
"""Draws YOLO predictions on an image with a dynamically sized font."""
|
|
@@ -100,23 +107,20 @@ def visualize_embedding(embedding):
|
|
| 100 |
buf.seek(0)
|
| 101 |
return Image.open(buf)
|
| 102 |
|
| 103 |
-
# --- Main Processing Function ---
|
| 104 |
def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
|
| 105 |
-
"""Performs the main analysis
|
| 106 |
if image is None:
|
| 107 |
raise gr.Error("Please upload an image first!")
|
| 108 |
|
| 109 |
# Calculate a dynamic font size based on image width.
|
| 110 |
-
# The font size will be 4% of the image width, with a minimum size of 15.
|
| 111 |
dynamic_font_size = max(15, int(image.width * 0.04))
|
| 112 |
try:
|
| 113 |
font = ImageFont.truetype("arial.ttf", dynamic_font_size)
|
| 114 |
except IOError:
|
| 115 |
font = ImageFont.load_default(size=dynamic_font_size)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
architecture = "grounding_dino"
|
| 119 |
-
|
| 120 |
model = get_model(task, architecture)
|
| 121 |
outputs = {"annotated_image": None, "model_info": "", "classes_info": "", "embedding_plot": None}
|
| 122 |
|
|
@@ -126,7 +130,7 @@ def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold
|
|
| 126 |
features = model.extract_features(image)
|
| 127 |
outputs["model_info"] = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
|
| 128 |
outputs["classes_info"] = f"Classes: {model.get_classes()}"
|
| 129 |
-
else:
|
| 130 |
if not text_prompt:
|
| 131 |
raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")
|
| 132 |
|
|
@@ -136,17 +140,18 @@ def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold
|
|
| 136 |
box_threshold=box_threshold,
|
| 137 |
text_threshold=text_threshold
|
| 138 |
)
|
| 139 |
-
|
| 140 |
outputs["annotated_image"] = draw_dino_predictions(image, results, font=font)
|
| 141 |
features = model.extract_features(image, text_prompt=text_prompt)
|
| 142 |
-
outputs["model_info"] = f"Architecture:
|
| 143 |
outputs["classes_info"] = f"Prompt: '{text_prompt}'"
|
| 144 |
|
|
|
|
| 145 |
if isinstance(features, dict):
|
| 146 |
outputs["embedding_plot"] = visualize_embedding(features.get('last_hidden_state'))
|
| 147 |
else:
|
| 148 |
outputs["embedding_plot"] = visualize_embedding(features)
|
| 149 |
|
|
|
|
| 150 |
return outputs["annotated_image"], outputs["model_info"], outputs["classes_info"], outputs["embedding_plot"]
|
| 151 |
|
| 152 |
# --- Gradio UI ---
|
|
|
|
| 25 |
"grounding_dino": "grounding_dino_detect_model"
|
| 26 |
}
|
| 27 |
}
|
|
|
|
| 28 |
|
| 29 |
+
# --- CORRECTED MODEL MANAGEMENT ---
|
| 30 |
+
# Caching is removed to prevent errors from stateful models.
|
| 31 |
+
# This function now loads a fresh model for each analysis request.
|
| 32 |
def get_model(task, architecture):
    """Load a fresh model instance for the selected task and architecture.

    Nothing is cached here: every call builds a new model through
    ``ibbi.create_model`` so that state left behind by one analysis run
    cannot affect the next one.

    Args:
        task: Task name, used as the first-level key into ``MODEL_REGISTRY``.
        architecture: Architecture name, used as the second-level key. It is
            ignored when ``task`` is ``"Zero-Shot Detection"``, which always
            resolves to ``"grounding_dino"``.

    Returns:
        The model object returned by ``ibbi.create_model``.

    Raises:
        gr.Error: If the task/architecture pair is missing from
            ``MODEL_REGISTRY``, or if creating the model fails.
    """
    # For Zero-Shot, the architecture is always 'grounding_dino'.
    if task == "Zero-Shot Detection":
        architecture = "grounding_dino"

    # Keep the registry lookup in its own try block so that a KeyError raised
    # deep inside model creation is not misreported as a lookup failure.
    try:
        model_name = MODEL_REGISTRY[task][architecture]
    except KeyError as e:
        raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}") from e

    try:
        print(f"Loading a fresh model instance: {model_name}")
        model = ibbi.create_model(model_name, pretrained=True)
        print("Model loaded successfully!")
    except Exception as e:
        # UI boundary: surface any loading failure to the user as a gr.Error
        # while keeping the original exception attached as the cause.
        raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}") from e

    return model
|
| 51 |
|
| 52 |
# --- Visualization and Drawing Functions ---
|
|
|
|
| 53 |
|
| 54 |
def draw_yolo_predictions(image, results, font, color="red"):
|
| 55 |
"""Draws YOLO predictions on an image with a dynamically sized font."""
|
|
|
|
| 107 |
buf.seek(0)
|
| 108 |
return Image.open(buf)
|
| 109 |
|
| 110 |
+
# --- CORRECTED Main Processing Function ---
|
| 111 |
def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
|
| 112 |
+
"""Performs the main analysis with corrected logic."""
|
| 113 |
if image is None:
|
| 114 |
raise gr.Error("Please upload an image first!")
|
| 115 |
|
| 116 |
# Calculate a dynamic font size based on image width.
|
|
|
|
| 117 |
dynamic_font_size = max(15, int(image.width * 0.04))
|
| 118 |
try:
|
| 119 |
font = ImageFont.truetype("arial.ttf", dynamic_font_size)
|
| 120 |
except IOError:
|
| 121 |
font = ImageFont.load_default(size=dynamic_font_size)
|
| 122 |
|
| 123 |
+
# Get a fresh model instance to avoid stateful errors
|
|
|
|
|
|
|
| 124 |
model = get_model(task, architecture)
|
| 125 |
outputs = {"annotated_image": None, "model_info": "", "classes_info": "", "embedding_plot": None}
|
| 126 |
|
|
|
|
| 130 |
features = model.extract_features(image)
|
| 131 |
outputs["model_info"] = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
|
| 132 |
outputs["classes_info"] = f"Classes: {model.get_classes()}"
|
| 133 |
+
else: # Zero-Shot Detection
|
| 134 |
if not text_prompt:
|
| 135 |
raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")
|
| 136 |
|
|
|
|
| 140 |
box_threshold=box_threshold,
|
| 141 |
text_threshold=text_threshold
|
| 142 |
)
|
|
|
|
| 143 |
outputs["annotated_image"] = draw_dino_predictions(image, results, font=font)
|
| 144 |
features = model.extract_features(image, text_prompt=text_prompt)
|
| 145 |
+
outputs["model_info"] = f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
|
| 146 |
outputs["classes_info"] = f"Prompt: '{text_prompt}'"
|
| 147 |
|
| 148 |
+
# Process features for visualization
|
| 149 |
if isinstance(features, dict):
|
| 150 |
outputs["embedding_plot"] = visualize_embedding(features.get('last_hidden_state'))
|
| 151 |
else:
|
| 152 |
outputs["embedding_plot"] = visualize_embedding(features)
|
| 153 |
|
| 154 |
+
# Correctly placed return statement ensures all outputs are always returned
|
| 155 |
return outputs["annotated_image"], outputs["model_info"], outputs["classes_info"], outputs["embedding_plot"]
|
| 156 |
|
| 157 |
# --- Gradio UI ---
|