Godreign commited on
Commit
4c73f8b
Β·
verified Β·
1 Parent(s): 0d447bd

edit clas based atention

Browse files
Files changed (1) hide show
  1. app.py +44 -33
app.py CHANGED
@@ -24,12 +24,10 @@ MODEL_CONFIGS = {
24
  "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
25
  "ResNet-50": {"type": "timm", "id": "resnet50"},
26
  "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
27
- "MobileNet-V3": {"type": "timm", "id": "mobilenetv3_small_100"},
28
  "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
29
  "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
30
  "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
31
- "RegNetY-002": {"type": "timm", "id": "regnety_002"},
32
- "SqueezeNet": {"type": "timm", "id": "squeezenet1_1"}
33
  }
34
 
35
  # ---------------------------
@@ -290,12 +288,13 @@ def get_class_specific_attention(image, model_name, class_query):
290
  if image is None:
291
  return None, "Please upload an image first"
292
 
293
- if not class_query:
294
  return None, "Please enter a class name"
295
 
296
  # Find matching class
297
  class_query_lower = class_query.lower().strip()
298
  matching_idx = None
 
299
 
300
  model, extractor = load_model(model_name)
301
 
@@ -304,10 +303,11 @@ def get_class_specific_attention(image, model_name, class_query):
304
  for idx, label in model.config.id2label.items():
305
  if class_query_lower in label.lower():
306
  matching_idx = idx
 
307
  break
308
 
309
  if matching_idx is None:
310
- return None, f"Class '{class_query}' not found in model labels"
311
 
312
  # Get attention for this class
313
  att_map = vit_attention_for_class(model, extractor, image, matching_idx)
@@ -317,10 +317,11 @@ def get_class_specific_attention(image, model_name, class_query):
317
  for idx, label in enumerate(IMAGENET_LABELS):
318
  if class_query_lower in label.lower():
319
  matching_idx = idx
 
320
  break
321
 
322
  if matching_idx is None:
323
- return None, f"Class '{class_query}' not found in ImageNet labels"
324
 
325
  # Get Grad-CAM for this class
326
  transform = T.Compose([
@@ -334,37 +335,43 @@ def get_class_specific_attention(image, model_name, class_query):
334
  att_map = get_gradcam_for_class(model, x, matching_idx)
335
 
336
  overlay = overlay_attention(image, att_map)
337
- return overlay, f"Attention map for class: {class_query}"
338
 
339
  except Exception as e:
340
  import traceback
341
- return None, f"Error: {str(e)}\n{traceback.format_exc()}"
 
 
342
 
343
 
344
  # ---------------------------
345
  # Get Class Suggestions
346
  # ---------------------------
347
  def get_class_suggestions(query, model_name):
348
- if not query or len(query) < 2:
349
- return []
350
-
351
- query_lower = query.lower()
352
- suggestions = []
353
-
354
- model, extractor = load_model(model_name)
355
-
356
- if MODEL_CONFIGS[model_name]["type"] == "hf":
357
- labels = list(model.config.id2label.values())
358
- else:
359
- labels = IMAGENET_LABELS
360
-
361
- for label in labels:
362
- if query_lower in label.lower():
363
- suggestions.append(label)
364
- if len(suggestions) >= 10: # Limit to 10 suggestions
365
- break
366
-
367
- return suggestions
 
 
 
 
368
 
369
 
370
  # ---------------------------
@@ -372,7 +379,7 @@ def get_class_suggestions(query, model_name):
372
  # ---------------------------
373
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
374
  gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
375
- gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 15+ Models")
376
 
377
  with gr.Row():
378
  with gr.Column(scale=1):
@@ -412,11 +419,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
412
  info="Start typing to see suggestions"
413
  )
414
  class_suggestions = gr.Dropdown(
415
- label="πŸ’‘ Suggestions",
416
  choices=[],
417
- interactive=True
 
418
  )
419
- class_button = gr.Button("🎯 Generate Class-Specific Attention")
420
 
421
  with gr.Column(scale=2):
422
  class_output_image = gr.Image(label="πŸ” Class-Specific Attention Map")
@@ -437,18 +445,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
437
  outputs=[output_label, output_image, processed_image]
438
  )
439
 
 
440
  class_input.change(
441
  get_class_suggestions,
442
  inputs=[class_input, model_dropdown],
443
  outputs=[class_suggestions]
444
  )
445
 
446
- class_suggestions.change(
 
447
  lambda x: x,
448
  inputs=[class_suggestions],
449
  outputs=[class_input]
450
  )
451
 
 
452
  class_button.click(
453
  get_class_specific_attention,
454
  inputs=[input_image, model_dropdown, class_input],
 
24
  "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
25
  "ResNet-50": {"type": "timm", "id": "resnet50"},
26
  "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
 
27
  "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
28
  "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
29
  "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
30
+ "RegNetY-002": {"type": "timm", "id": "regnety_002"}
 
31
  }
32
 
33
  # ---------------------------
 
288
  if image is None:
289
  return None, "Please upload an image first"
290
 
291
+ if not class_query or class_query.strip() == "":
292
  return None, "Please enter a class name"
293
 
294
  # Find matching class
295
  class_query_lower = class_query.lower().strip()
296
  matching_idx = None
297
+ matched_label = None
298
 
299
  model, extractor = load_model(model_name)
300
 
 
303
  for idx, label in model.config.id2label.items():
304
  if class_query_lower in label.lower():
305
  matching_idx = idx
306
+ matched_label = label
307
  break
308
 
309
  if matching_idx is None:
310
+ return None, f"Class '{class_query}' not found in model labels. Try a different class name or check suggestions."
311
 
312
  # Get attention for this class
313
  att_map = vit_attention_for_class(model, extractor, image, matching_idx)
 
317
  for idx, label in enumerate(IMAGENET_LABELS):
318
  if class_query_lower in label.lower():
319
  matching_idx = idx
320
+ matched_label = label
321
  break
322
 
323
  if matching_idx is None:
324
+ return None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check suggestions."
325
 
326
  # Get Grad-CAM for this class
327
  transform = T.Compose([
 
335
  att_map = get_gradcam_for_class(model, x, matching_idx)
336
 
337
  overlay = overlay_attention(image, att_map)
338
+ return overlay, f"βœ“ Attention map generated for class: '{matched_label}'"
339
 
340
  except Exception as e:
341
  import traceback
342
+ error_trace = traceback.format_exc()
343
+ print(error_trace)
344
+ return None, f"Error generating attention map: {str(e)}"
345
 
346
 
347
  # ---------------------------
348
  # Get Class Suggestions
349
  # ---------------------------
350
  def get_class_suggestions(query, model_name):
351
+ try:
352
+ if not query or len(query.strip()) < 2:
353
+ return gr.Dropdown(choices=[], value=None)
354
+
355
+ query_lower = query.lower().strip()
356
+ suggestions = []
357
+
358
+ model, extractor = load_model(model_name)
359
+
360
+ if MODEL_CONFIGS[model_name]["type"] == "hf":
361
+ labels = list(model.config.id2label.values())
362
+ else:
363
+ labels = IMAGENET_LABELS
364
+
365
+ for label in labels:
366
+ if query_lower in label.lower():
367
+ suggestions.append(label)
368
+ if len(suggestions) >= 10: # Limit to 10 suggestions
369
+ break
370
+
371
+ return gr.Dropdown(choices=suggestions, value=None)
372
+ except Exception as e:
373
+ print(f"Error getting suggestions: {e}")
374
+ return gr.Dropdown(choices=[], value=None)
375
 
376
 
377
  # ---------------------------
 
379
  # ---------------------------
380
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
381
  gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
382
+ gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
383
 
384
  with gr.Row():
385
  with gr.Column(scale=1):
 
419
  info="Start typing to see suggestions"
420
  )
421
  class_suggestions = gr.Dropdown(
422
+ label="πŸ’‘ Suggestions (click to use)",
423
  choices=[],
424
+ interactive=True,
425
+ allow_custom_value=False
426
  )
427
+ class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
428
 
429
  with gr.Column(scale=2):
430
  class_output_image = gr.Image(label="πŸ” Class-Specific Attention Map")
 
445
  outputs=[output_label, output_image, processed_image]
446
  )
447
 
448
+ # Update suggestions as user types
449
  class_input.change(
450
  get_class_suggestions,
451
  inputs=[class_input, model_dropdown],
452
  outputs=[class_suggestions]
453
  )
454
 
455
+ # When user selects from suggestions, update the text input
456
+ class_suggestions.select(
457
  lambda x: x,
458
  inputs=[class_suggestions],
459
  outputs=[class_input]
460
  )
461
 
462
+ # Generate attention map
463
  class_button.click(
464
  get_class_specific_attention,
465
  inputs=[input_image, model_dropdown, class_input],