ChristopherMarais committed on
Commit 929de5c · verified · 1 Parent(s): f4a6ba2

Update app.py

Files changed (1)
  1. app.py +22 -17
app.py CHANGED
@@ -25,24 +25,31 @@ MODEL_REGISTRY = {
         "grounding_dino": "grounding_dino_detect_model"
     }
 }
-LOADED_MODELS = {}
 
+# --- CORRECTED MODEL MANAGEMENT ---
+# Caching is removed to prevent errors from stateful models.
+# This function now loads a fresh model for each analysis request.
 def get_model(task, architecture):
-    """Lazily loads a model based on user selection and caches it."""
+    """
+    Loads a fresh model instance based on user selection.
+    This prevents stateful changes from one run affecting the next.
+    """
     try:
+        # For Zero-Shot, the architecture is always 'grounding_dino'
+        if task == "Zero-Shot Detection":
+            architecture = "grounding_dino"
+
         model_name = MODEL_REGISTRY[task][architecture]
-        if model_name not in LOADED_MODELS:
-            print(f"Loading model for the first time: {model_name}")
-            LOADED_MODELS[model_name] = ibbi.create_model(model_name, pretrained=True)
-            print("Model loaded successfully!")
-        return LOADED_MODELS[model_name]
+        print(f"Loading a fresh model instance: {model_name}")
+        model = ibbi.create_model(model_name, pretrained=True)
+        print("Model loaded successfully!")
+        return model
     except KeyError as e:
         raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}")
     except Exception as e:
         raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}")
 
 # --- Visualization and Drawing Functions ---
-# Note: The global font object has been removed from here.
 
 def draw_yolo_predictions(image, results, font, color="red"):
     """Draws YOLO predictions on an image with a dynamically sized font."""
@@ -100,23 +107,20 @@ def visualize_embedding(embedding):
     buf.seek(0)
     return Image.open(buf)
 
-# --- Main Processing Function ---
+# --- CORRECTED Main Processing Function ---
 def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
-    """Performs the main analysis, including dynamic font calculation."""
+    """Performs the main analysis with corrected logic."""
     if image is None:
         raise gr.Error("Please upload an image first!")
 
     # Calculate a dynamic font size based on image width.
-    # The font size will be 4% of the image width, with a minimum size of 15.
     dynamic_font_size = max(15, int(image.width * 0.04))
     try:
         font = ImageFont.truetype("arial.ttf", dynamic_font_size)
     except IOError:
         font = ImageFont.load_default(size=dynamic_font_size)
 
-    if task == "Zero-Shot Detection":
-        architecture = "grounding_dino"
-
+    # Get a fresh model instance to avoid stateful errors
     model = get_model(task, architecture)
     outputs = {"annotated_image": None, "model_info": "", "classes_info": "", "embedding_plot": None}
 
@@ -126,7 +130,7 @@ def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold
         features = model.extract_features(image)
         outputs["model_info"] = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
         outputs["classes_info"] = f"Classes: {model.get_classes()}"
-    else: # Zero-Shot Detection
+    else: # Zero-Shot Detection
         if not text_prompt:
             raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")
 
@@ -136,17 +140,18 @@ def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold
             box_threshold=box_threshold,
             text_threshold=text_threshold
         )
-
         outputs["annotated_image"] = draw_dino_predictions(image, results, font=font)
         features = model.extract_features(image, text_prompt=text_prompt)
-        outputs["model_info"] = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
+        outputs["model_info"] = f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
         outputs["classes_info"] = f"Prompt: '{text_prompt}'"
 
+    # Process features for visualization
     if isinstance(features, dict):
         outputs["embedding_plot"] = visualize_embedding(features.get('last_hidden_state'))
     else:
         outputs["embedding_plot"] = visualize_embedding(features)
 
+    # Correctly placed return statement ensures all outputs are always returned
     return outputs["annotated_image"], outputs["model_info"], outputs["classes_info"], outputs["embedding_plot"]
 
 # --- Gradio UI ---
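The main behavioural change in this commit is that get_model no longer consults a LOADED_MODELS cache: every analysis request rebuilds its model through ibbi.create_model. A minimal sketch of what that means for callers, assuming the ibbi package and the get_model/MODEL_REGISTRY definitions from the updated app.py (the task and architecture strings are taken from the diff; the assertion is illustrative only):

    # Illustrative only: relies on get_model exactly as defined in the updated app.py.
    # For "Zero-Shot Detection" the architecture argument is overridden to
    # "grounding_dino" inside get_model, so it is effectively ignored here.
    model_a = get_model("Zero-Shot Detection", "grounding_dino")
    model_b = get_model("Zero-Shot Detection", "grounding_dino")
    assert model_a is not model_b  # no caching: each call builds a fresh instance

The trade-off is that every request now pays the full model-load time that the old LOADED_MODELS dictionary amortized, in exchange for guaranteeing that one run cannot leave state behind for the next.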
 
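On the label-rendering side, the diff keeps the dynamic font logic: the font size is 4% of the image width with a floor of 15, and falls back to Pillow's default font when arial.ttf is not available. A small self-contained sketch of that sizing rule, assuming only Pillow (the helper name pick_label_font is not part of app.py):

    from PIL import ImageFont

    def pick_label_font(image_width):
        # 4% of the image width, never smaller than 15 px (mirrors app.py).
        size = max(15, int(image_width * 0.04))
        try:
            return ImageFont.truetype("arial.ttf", size)   # used when Arial is installed
        except IOError:
            return ImageFont.load_default(size=size)       # sized default needs Pillow >= 10.1

    print(max(15, int(1000 * 0.04)))  # 1000 px wide image -> 40 px font
    print(max(15, int(300 * 0.04)))   # 300 px wide image  -> clamped to the 15 px floor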