Edit class-based attention
app.py
CHANGED
@@ -24,12 +24,10 @@ MODEL_CONFIGS = {
     "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
     "ResNet-50": {"type": "timm", "id": "resnet50"},
     "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
-    "MobileNet-V3": {"type": "timm", "id": "mobilenetv3_small_100"},
     "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
     "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
     "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
-    "RegNetY-002": {"type": "timm", "id": "regnety_002"}
-    "SqueezeNet": {"type": "timm", "id": "squeezenet1_1"}
+    "RegNetY-002": {"type": "timm", "id": "regnety_002"}
 }
 
 # ---------------------------
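The surviving timm entries are resolved by the app's load_model helper, which is outside this diff. As a rough sketch only (assuming a recent timm; names here are illustrative, not the app's actual code), such an entry can be turned into a model plus matching preprocessing like this:

import timm

def load_timm_entry(config):
    # Resolve an entry such as {"type": "timm", "id": "regnety_002"} into a
    # pretrained classifier and the preprocessing transform it expects.
    model = timm.create_model(config["id"], pretrained=True)
    model.eval()
    data_cfg = timm.data.resolve_model_data_config(model)
    transform = timm.data.create_transform(**data_cfg, is_training=False)
    return model, transform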
@@ -290,12 +288,13 @@ def get_class_specific_attention(image, model_name, class_query):
     if image is None:
         return None, "Please upload an image first"
 
-    if not class_query:
+    if not class_query or class_query.strip() == "":
         return None, "Please enter a class name"
 
     # Find matching class
     class_query_lower = class_query.lower().strip()
     matching_idx = None
+    matched_label = None
 
     model, extractor = load_model(model_name)
 
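The tightened guard now also rejects whitespace-only queries before the string is normalised; a quick standalone check of that boolean (not app code):

for q in [None, "", "   ", "tabby cat"]:
    rejected = not q or q.strip() == ""
    print(repr(q), "rejected" if rejected else "accepted")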
@@ -304,10 +303,11 @@ def get_class_specific_attention(image, model_name, class_query):
         for idx, label in model.config.id2label.items():
             if class_query_lower in label.lower():
                 matching_idx = idx
+                matched_label = label
                 break
 
         if matching_idx is None:
-            return None, f"Class '{class_query}' not found in model labels"
+            return None, f"Class '{class_query}' not found in model labels. Try a different class name or check suggestions."
 
         # Get attention for this class
         att_map = vit_attention_for_class(model, extractor, image, matching_idx)
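vit_attention_for_class is defined elsewhere in app.py and is not shown in this commit. One common way such a helper produces a class-specific map for a ViT is to weight the last layer's CLS-to-patch attention by the gradient of the chosen class logit; the sketch below only illustrates that idea (it assumes a Hugging Face ViT classifier and is not the app's implementation):

import torch

def class_weighted_cls_attention(model, extractor, image, class_idx):
    # Assumes the model returns attentions when output_attentions=True
    # (recent transformers fall back to eager attention in that case).
    inputs = extractor(images=image, return_tensors="pt")
    outputs = model(**inputs, output_attentions=True)
    att = outputs.attentions[-1]                     # [1, heads, tokens, tokens]
    # Gradient of the chosen class logit w.r.t. the last attention map.
    grad = torch.autograd.grad(outputs.logits[0, class_idx], att)[0]
    weighted = (att * grad.clamp(min=0)).mean(dim=1)[0]   # average over heads
    cls_to_patches = weighted[0, 1:]                 # CLS row, patch columns only
    side = int(cls_to_patches.numel() ** 0.5)        # e.g. 14 for a 224/16 ViT
    cam = cls_to_patches.reshape(side, side)
    cam = cam - cam.min()
    return (cam / (cam.max() + 1e-8)).detach().numpy()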
@@ -317,10 +317,11 @@ def get_class_specific_attention(image, model_name, class_query):
         for idx, label in enumerate(IMAGENET_LABELS):
             if class_query_lower in label.lower():
                 matching_idx = idx
+                matched_label = label
                 break
 
         if matching_idx is None:
-            return None, f"Class '{class_query}' not found in ImageNet labels"
+            return None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check suggestions."
 
         # Get Grad-CAM for this class
         transform = T.Compose([
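The hunk cuts off inside the T.Compose([...]) call, so the exact preprocessing is not visible here. For an ImageNet-pretrained backbone it is typically the standard resize/crop/normalize pipeline below (values are the usual ImageNet statistics, assumed rather than taken from app.py):

from PIL import Image
import torchvision.transforms as T

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],  # standard ImageNet statistics
                std=[0.229, 0.224, 0.225]),
])

image = Image.new("RGB", (640, 480))   # stand-in for the uploaded image
x = transform(image).unsqueeze(0)      # batched tensor fed to the model
print(x.shape)                         # torch.Size([1, 3, 224, 224])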
@@ -334,37 +335,43 @@ def get_class_specific_attention(image, model_name, class_query):
         att_map = get_gradcam_for_class(model, x, matching_idx)
 
         overlay = overlay_attention(image, att_map)
-        return overlay, f"Attention map for class: {
+        return overlay, f"✅ Attention map generated for class: '{matched_label}'"
 
     except Exception as e:
         import traceback
+        error_trace = traceback.format_exc()
+        print(error_trace)
+        return None, f"Error generating attention map: {str(e)}"
 
 
 # ---------------------------
 # Get Class Suggestions
 # ---------------------------
 def get_class_suggestions(query, model_name):
+    try:
+        if not query or len(query.strip()) < 2:
+            return gr.Dropdown(choices=[], value=None)
+
+        query_lower = query.lower().strip()
+        suggestions = []
+
+        model, extractor = load_model(model_name)
+
+        if MODEL_CONFIGS[model_name]["type"] == "hf":
+            labels = list(model.config.id2label.values())
+        else:
+            labels = IMAGENET_LABELS
+
+        for label in labels:
+            if query_lower in label.lower():
+                suggestions.append(label)
+            if len(suggestions) >= 10:  # Limit to 10 suggestions
+                break
+
+        return gr.Dropdown(choices=suggestions, value=None)
+    except Exception as e:
+        print(f"Error getting suggestions: {e}")
+        return gr.Dropdown(choices=[], value=None)
 
 
 # ---------------------------
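get_gradcam_for_class is another helper defined outside this diff. A class-specific Grad-CAM map is usually computed by weighting a chosen convolutional layer's activations with the spatially averaged gradient of the target class logit; the hook-based sketch below illustrates that recipe (names and layer choice are assumptions, not the app's code):

import torch
import torch.nn.functional as F

def gradcam_for_class(model, x, class_idx, target_layer):
    # model: CNN classifier in eval mode; x: preprocessed [1, 3, H, W] batch;
    # target_layer: conv block whose activations are weighted, chosen by the
    # caller (e.g. the last stage of a timm ResNet).
    feats, grads = {}, {}

    def save_feats(_, __, output):
        feats["v"] = output

    def save_grads(_, grad_in, grad_out):
        grads["v"] = grad_out[0]

    h1 = target_layer.register_forward_hook(save_feats)
    h2 = target_layer.register_full_backward_hook(save_grads)
    try:
        logits = model(x)
        model.zero_grad()
        logits[0, class_idx].backward()
    finally:
        h1.remove()
        h2.remove()

    # Channel weights = spatial mean of the gradients; ReLU keeps only the
    # evidence that increases the target class score.
    weights = grads["v"].mean(dim=(2, 3), keepdim=True)            # [1, C, 1, 1]
    cam = F.relu((weights * feats["v"]).sum(dim=1, keepdim=True))  # [1, 1, h, w]
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear",
                        align_corners=False)[0, 0]
    cam = cam - cam.min()
    return (cam / (cam.max() + 1e-8)).detach().cpu().numpy()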
@@ -372,7 +379,7 @@ def get_class_suggestions(query, model_name):
 # ---------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
-    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention |
+    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -412,11 +419,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 info="Start typing to see suggestions"
             )
             class_suggestions = gr.Dropdown(
-                label="💡 Suggestions",
+                label="💡 Suggestions (click to use)",
                 choices=[],
-                interactive=True
+                interactive=True,
+                allow_custom_value=False
             )
-            class_button = gr.Button("🎯 Generate Class-Specific Attention")
+            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
 
         with gr.Column(scale=2):
             class_output_image = gr.Image(label="π Class-Specific Attention Map")
@@ -437,18 +445,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         outputs=[output_label, output_image, processed_image]
     )
 
+    # Update suggestions as user types
     class_input.change(
         get_class_suggestions,
         inputs=[class_input, model_dropdown],
         outputs=[class_suggestions]
    )
 
+    # When user selects from suggestions, update the text input
+    class_suggestions.select(
         lambda x: x,
         inputs=[class_suggestions],
         outputs=[class_input]
     )
 
+    # Generate attention map
     class_button.click(
         get_class_specific_attention,
         inputs=[input_image, model_dropdown, class_input],
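The select handler simply copies the chosen suggestion back into the text box. A minimal, self-contained sketch of this suggestion wiring with a toy label list in place of the app's models (Gradio 4.x style, matching the component-return pattern used in get_class_suggestions):

import gradio as gr

LABELS = ["tabby cat", "tiger cat", "golden retriever", "goldfish"]

def suggest(query):
    # Mirrors get_class_suggestions: substring match, capped at 10 entries.
    if not query or len(query.strip()) < 2:
        return gr.Dropdown(choices=[], value=None)
    q = query.lower().strip()
    return gr.Dropdown(choices=[lbl for lbl in LABELS if q in lbl.lower()][:10], value=None)

with gr.Blocks() as demo:
    class_input = gr.Textbox(label="Class name")
    class_suggestions = gr.Dropdown(label="Suggestions", choices=[], interactive=True)
    class_input.change(suggest, inputs=[class_input], outputs=[class_suggestions])
    class_suggestions.select(lambda x: x, inputs=[class_suggestions], outputs=[class_input])

if __name__ == "__main__":
    demo.launch()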