# IBBI / app.py
# Hugging Face page header captured together with the file:
#   author: ChristopherMarais
#   last commit: "Update app.py" (9b18251, verified)
import random
import os
import numpy as np
import gradio as gr
from PIL import Image
from groundingdino.util.inference import load_model as load_groundingdino_model
from groundingdino.util.inference import predict as grounding_dino_predict
import groundingdino.datasets.transforms as T
import torch
from torchvision.ops import box_convert
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import GaussianBlur
import time
# ----------------------------
# DINOv2 Classifier Imports
# ----------------------------
import torch.nn as nn
from torchvision import transforms
import pandas as pd
from typing import List, Tuple
import copy
import matplotlib.pyplot as plt
# ----------------------------
# DINOv2 Classifier Definitions
# ----------------------------
# 1. PadToSquare Class
class PadToSquare:
    """Callable transform that pads an image onto a square canvas.

    The shorter dimension is padded on both sides with a constant value.
    When the missing amount is odd, the extra pixel goes to the right
    (or bottom) edge.
    """

    def __init__(self, fill=0):
        # Constant value used for the added border.
        self.fill = fill

    def __call__(self, img):
        width, height = img.size
        side = max(width, height)
        missing_w = side - width
        missing_h = side - height
        left = missing_w // 2
        top = missing_h // 2
        # Padding order expected by torchvision: (left, top, right, bottom).
        borders = (left, top, missing_w - left, missing_h - top)
        return transforms.functional.pad(
            img, borders, fill=self.fill, padding_mode='constant'
        )
# 2. DinoVisionTransformerClassifier Class (Modified to include entropy-based approach)
class DinoVisionTransformerClassifier(nn.Module):
    """DINOv2 ViT-S/14 backbone followed by a small MLP classification head.

    The backbone's final ``norm`` layer is replaced with an identity and a
    ``BatchNorm1d`` is applied to its 384-dimensional output instead. The
    head is Linear -> BatchNorm1d -> LeakyReLU -> Dropout -> Linear.

    Args:
        num_classes: Number of output logits.
        hidden_size: Width of the hidden layer in the classification head.
        dropout_p: Dropout probability used in the head.
        negative_slope: Slope of the LeakyReLU; also used as the ``a``
            parameter of the Kaiming initialisation so both stay in sync.
    """

    def __init__(self, num_classes, hidden_size=256, dropout_p=0.5, negative_slope=0.01):
        super().__init__()
        # Remembered so _initialize_weights matches the activation's slope
        # instead of assuming the default value.
        self.negative_slope = negative_slope
        embed_dim = 384  # output size of dinov2_vits14

        # Backbone is fetched through torch.hub.
        self.transformer = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
        # Drop the backbone's own final normalisation; batch_norm1 is used instead.
        self.transformer.norm = nn.Identity()
        self.batch_norm1 = nn.BatchNorm1d(embed_dim)

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
            nn.Dropout(p=dropout_p),
            nn.Linear(hidden_size, num_classes),
        )
        self._initialize_weights()

    def forward(self, x):
        """Return ``(logits, features)`` for the input batch ``x``.

        ``features`` is the batch-normalised backbone output that is fed to
        the classification head.
        """
        features = self.batch_norm1(self.transformer(x))
        logits = self.classifier(features)
        return logits, features

    def _initialize_weights(self):
        """Initialise the classification head (the backbone is left untouched)."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                # Use the same slope as the LeakyReLU that follows the layer.
                nn.init.kaiming_normal_(
                    m.weight, a=self.negative_slope, mode='fan_in', nonlinearity='leaky_relu'
                )
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
# 3. Model Loading Function (Updated for Entropy-Based Classifier)
def load_model(model_path, device):
    """Rebuild the classifier from a saved checkpoint.

    The checkpoint is expected to hold ``'class_names'`` and
    ``'model_state_dict'`` entries.

    Args:
        model_path (str): Path to the saved .pth checkpoint.
        device (torch.device): Device the weights are mapped and moved to.

    Returns:
        tuple: ``(model, class_names)`` where ``model`` is in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file '{model_path}' does not exist.")

    # NOTE: torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here.
    ckpt = torch.load(model_path, map_location=device)
    labels = ckpt['class_names']

    # The head size is derived from the number of stored class names.
    net = DinoVisionTransformerClassifier(num_classes=len(labels))
    net.load_state_dict(ckpt['model_state_dict'])
    net.to(device)
    net.eval()
    return net, labels
# 4. Image Preprocessing Function (Updated to accept PIL Image directly)
def preprocess_image_pil(pil_image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """Run *pil_image* through *transform* and hand back the result.

    Args:
        pil_image (PIL.Image.Image): Image to preprocess.
        transform (transforms.Compose): Callable transformation pipeline.

    Returns:
        torch.Tensor: Whatever the pipeline produces for the image.
    """
    processed = transform(pil_image)
    return processed
# ----------------------------
# Gradio App Definitions
# ----------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Text prompt passed to the object detector by predict_beetle.
PROMPT = "bug"
# Define a custom transform for Gaussian blur (Unused in current context)
def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
    """Randomly blur the items of a 4-D batch in place.

    Each ``x[i]`` is, with probability ``p``, replaced by a Gaussian-blurred
    copy using a random odd kernel size and a random sigma. Inputs whose
    ``ndim`` is not 4 are returned untouched.

    Args:
        x: Batch to blur; presumably a (batch, channels, height, width)
            tensor - confirm against callers. Modified in place.
        p: Per-item probability of applying the blur.
        kernel_size_min: Lower bound for the kernel size. Even values are
            bumped to the next odd number because GaussianBlur only accepts
            odd kernel sizes.
        kernel_size_max: Inclusive upper bound for the kernel size.
        sigma_min: Lower bound of the sampled sigma.
        sigma_max: Upper bound of the sampled sigma.

    Returns:
        The same object ``x`` that was passed in.
    """
    if x.ndim != 4:
        return x

    # Stepping by 2 from an odd start keeps every sampled size odd.
    first_odd = kernel_size_min | 1
    for i in range(x.shape[0]):
        if random.random() < p:
            kernel_size = random.randrange(first_odd, kernel_size_max + 1, 2)
            sigma = random.uniform(sigma_min, sigma_max)
            x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
    return x
# Custom Label Function (Unused in current context)
def custom_label_func(fpath):
    """Return the label encoded in the directory layout of *fpath*.

    The label is the name of ``fpath.parents[2]``, i.e. the directory two
    levels above the folder that contains the file.
    """
    ancestors = fpath.parents
    return ancestors[2].name
# Image loading function for GroundingDINO
def load_image(image_source):
    """Convert an image into the tensor passed to the GroundingDINO model.

    The image is converted to RGB, resized with ``RandomResize([800],
    max_size=1333)``, turned into a tensor and normalised with the
    ImageNet mean and standard deviation.

    Args:
        image_source: Image object exposing ``convert("RGB")``.

    Returns:
        The transformed image (the transform's target output is discarded).
    """
    rgb_image = image_source.convert("RGB")
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # These transforms take (image, target); no target is needed here.
    tensor, _unused_target = pipeline(rgb_image, None)
    return tensor
# GroundingDINO object detector; checkpoint and config are looked up
# relative to the working directory.
od_model = load_groundingdino_model(
    model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
    model_checkpoint_path="groundingdino_swint_ogc.pth",
    device=DEVICE,
)
print("Object detection model loaded")

# DINOv2 classifier checkpoint. No explicit "Unknown" class is appended to
# class_names; classify_beetle adds an "Unknown" entry at prediction time.
MODEL_PATH = 'dinov2_classifier_with_vos_unsure.pth'
dinov2_model, class_names = load_model(MODEL_PATH, torch.device(DEVICE))
print(f"DINOv2 Classification model loaded with {len(class_names)} classes.")

# Rename one label stored in the checkpoint before it is shown to users.
target = "Scolotodes_schwarzi"
try:
    idx = class_names.index(target)
except ValueError:
    print(f"'{target}' not found in class names. No replacement made.")
else:
    class_names[idx] = "Scolytodes_glaber"
    print(f"Replaced '{target}' with 'Scolytodes_glaber' in class names.")

# Preprocessing for the DINOv2 classifier: shrink the short edge to 224,
# pad to a square, resize to 224x224, then normalise with ImageNet statistics.
dinov2_transform = transforms.Compose([
    transforms.Resize(224),
    PadToSquare(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Object Detection Function
def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
    """Detect objects matching *prompt* and return one crop per detection.

    Args:
        og_image: Input image as a numpy array or a PIL image.
        model: GroundingDINO model used for prediction.
        prompt: Text caption describing what to detect.
        device: Device string forwarded to the predictor ("cuda" or "cpu").

    Returns:
        list: PIL images cropped from the original image, one per usable
        detection. Boxes are clamped to the image bounds and boxes with no
        area after clamping are dropped.
    """
    # Minimum box / text scores forwarded to the detector.
    box_threshold = 0.15
    text_threshold = 0.15

    # Crops are taken from a PIL image, so wrap numpy input.
    if isinstance(og_image, np.ndarray):
        image = Image.fromarray(og_image)
    else:
        image = og_image

    boxes, _logits, _phrases = grounding_dino_predict(
        model=model,
        image=load_image(image_source=image),
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    # Boxes are scaled by the image size (i.e. treated as normalised cxcywh)
    # and then converted to corner format.
    width, height = image.size
    scale = torch.tensor(
        [width, height, width, height], dtype=boxes.dtype, device=boxes.device
    )
    xyxy = box_convert(boxes=boxes * scale, in_fmt="cxcywh", out_fmt="xyxy")

    img_lst = []
    for x0, y0, x1, y1 in xyxy.tolist():
        # Clamp to the image so crops are never padded with black pixels.
        left = max(0, round(x0))
        top = max(0, round(y0))
        right = min(width, round(x1))
        bottom = min(height, round(y1))
        if right <= left or bottom <= top:
            # An empty crop would break the classifier's resize step.
            continue
        img_lst.append(image.crop((left, top, right, bottom)))

    print(f"Detected {len(img_lst)} objects.")
    return img_lst
# Inference/Class Prediction Function using the Entropy-Based DINOv2 Classifier
def classify_beetle(img: Image.Image, threshold=75.0):
    """Classify one cropped image and report the top classes as percentages.

    Args:
        img (PIL.Image.Image): The image to classify.
        threshold (float): If the best class scores below this percentage,
            an "Unknown" entry is added to the result.

    Returns:
        dict: Up to three class names mapped to their confidence in percent
        (rounded to one decimal). When the top confidence is below
        ``threshold`` the dict also contains ``"Unknown"``, whose value is
        the normalised entropy of the prediction expressed in percent.
    """
    input_tensor = preprocess_image_pil(img, dinov2_transform).unsqueeze(0).to(torch.device(DEVICE))
    print(f"Input tensor shape: {input_tensor.shape}")

    with torch.no_grad():
        outputs, _ = dinov2_model(input_tensor)
    print(f"Model outputs: {outputs}")

    probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]  # values in [0, 1]
    print(f"Probabilities (0-1 scale): {probabilities}")

    # Shannon entropy; epsilon avoids log(0).
    epsilon = 1e-12
    entropy = float(-np.sum(probabilities * np.log(probabilities + epsilon)))

    # A uniform distribution over n classes has entropy log(n); dividing by
    # it maps the entropy to [0, 1]. A single class has no uncertainty.
    num_classes = len(probabilities)
    max_entropy = float(np.log(num_classes)) if num_classes > 1 else 0.0
    if max_entropy > 0.0:
        normalized_entropy = min(max(entropy / max_entropy, 0.0), 1.0)
    else:
        normalized_entropy = 0.0
    unknown_prob = normalized_entropy
    print(f"Entropy: {entropy}, Normalized Entropy: {normalized_entropy}, Unknown Probability: {unknown_prob}")

    probabilities_percent = np.around(probabilities * 100, decimals=1)
    print(f"Probabilities (Percentage): {probabilities_percent}")

    # Rank on the unrounded probabilities so rounding cannot create ties.
    top_indices = np.argsort(probabilities)[-3:][::-1]
    top_probs = probabilities_percent[top_indices]
    conf_dict = {
        class_names[index]: float(percent)
        for index, percent in zip(top_indices, top_probs)
    }

    # Every value in the dict is a percentage, so scale the entropy too.
    if top_probs[0] < threshold:
        conf_dict["Unknown"] = float(np.around(unknown_prob * 100, decimals=1))

    print(f"Conf_dict: {conf_dict}")
    return conf_dict
# Main Prediction Function for Gradio
def predict_beetle(img):
    """Detect every object in *img* and classify each detected crop.

    Args:
        img: Input image (numpy array or PIL image), or ``None`` when no
            image was supplied.

    Returns:
        list: One ``[cropped_image, confidence_dict]`` pair per detected
        object; an empty list when ``img`` is ``None`` or nothing is found.
    """
    # Without an image there is nothing to run the detector on.
    if img is None:
        print("No image provided; nothing to classify.")
        return []

    print("Detecting objects in the image...")
    start_time = time.perf_counter()

    image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE)
    img_cnt = len(image_lst)
    print(f"Detected {img_cnt} objects.")

    output_lst = []
    for position, crop in enumerate(image_lst, start=1):
        print(f"Classifying object {position}/{img_cnt}...")
        output_lst.append([crop, classify_beetle(crop)])
        print(f"Object {position} classified.")

    processing_time = time.perf_counter() - start_time
    print(f"Total processing duration: {processing_time:.2f} seconds")
    return output_lst
# ----------------------------
# Gradio Interface Setup
# ----------------------------
sample_images_dir = "example_images"

# Bundled example images offered below the upload widget.
example_images = [
    os.path.join(sample_images_dir, filename)
    for filename in ("example1.jpg", "example2.jpg", "example3.jpg", "mixed.jpg")
]

# Captions for the examples, in the same order as example_images.
example_labels = [
    "Example Beetles 1",
    "Example Beetles 2",
    "Example Beetles 3",
    "Example Beetles 4",
]

with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Intelligent Bark Beetle Identifier (IBBI)</center></h1>")

    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            inputs = gr.Image(label="Input Image")

        gr.Examples(
            label="Select an example below if you have no images to upload.",
            examples=example_images,
            inputs=inputs,
            examples_per_page=4,
            example_labels=example_labels,
        )

        btn = gr.Button("Classify", variant="primary")

        # Layout and height of the gallery are fixed at construction time.
        gallery = gr.Gallery(
            label="Classified Objects",
            show_label=True,
            elem_id="gallery",
            columns=4,
            height="auto",
        )

    def format_gallery(results):
        """Turn (image, confidence dict) pairs into (image, caption) tuples.

        The caption lists every entry of the dict as "name: value%".
        """
        return [
            (img, ", ".join(f"{k}: {v:.1f}%" for k, v in conf.items()))
            for img, conf in results
        ]

    # Run detection + classification on click and show captioned crops.
    btn.click(
        lambda img: format_gallery(predict_beetle(img)),
        inputs=inputs,
        outputs=gallery,
    )

# Start the app.
demo.launch(share=True, inline=True, debug=True, show_error=True)