Spaces:
Sleeping
Sleeping
import random | |
import os | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
from groundingdino.util.inference import load_model as load_groundingdino_model | |
from groundingdino.util.inference import predict as grounding_dino_predict | |
import groundingdino.datasets.transforms as T | |
import torch | |
from torchvision.ops import box_convert | |
from torchvision.transforms.functional import to_tensor | |
from torchvision.transforms import GaussianBlur | |
import time | |
# ---------------------------- | |
# DINOv2 Classifier Imports | |
# ---------------------------- | |
import torch.nn as nn | |
from torchvision import transforms | |
import pandas as pd | |
from typing import List, Tuple | |
import copy | |
import matplotlib.pyplot as plt | |
# ---------------------------- | |
# DINOv2 Classifier Definitions | |
# ---------------------------- | |
# 1. PadToSquare Class | |
class PadToSquare:
    """
    Transform that pads an image with a constant fill until it is square.

    The shorter dimension is extended on both sides. When the missing amount
    is odd, the first side (left/top) gets the smaller half and the second
    side (right/bottom) gets the remainder.
    """

    def __init__(self, fill=0):
        # Fill value forwarded unchanged to the padding call.
        self.fill = fill

    def __call__(self, img):
        # ``img.size`` is read as (width, height).
        width, height = img.size
        side = max(width, height)

        # Total padding needed along each axis (zero for the longer one).
        extra_w = side - width
        extra_h = side - height

        left, top = extra_w // 2, extra_h // 2
        right, bottom = extra_w - left, extra_h - top

        return transforms.functional.pad(
            img,
            (left, top, right, bottom),
            fill=self.fill,
            padding_mode='constant',
        )
# 2. DinoVisionTransformerClassifier Class (Modified to include entropy-based approach) | |
class DinoVisionTransformerClassifier(nn.Module):
    """
    Classifier made of a DINOv2 backbone followed by a small MLP head.

    The backbone is fetched through ``torch.hub`` (``dinov2_vits14``), its
    ``norm`` layer is replaced by an identity, and its output is passed through
    a batch-norm layer and a two-layer head producing ``num_classes`` logits.
    The module itself only returns logits and features; any "Unknown" decision
    is left to the caller.

    Args:
        num_classes (int): Number of output logits.
        hidden_size (int): Width of the hidden layer in the head.
        dropout_p (float): Dropout probability used in the head.
        negative_slope (float): Negative slope of the LeakyReLU in the head.
    """

    def __init__(self, num_classes, hidden_size=256, dropout_p=0.5, negative_slope=0.01):
        super().__init__()
        embed_dim = 384  # feature width expected from the backbone by the layers below

        # Backbone from torch.hub, with its own final normalisation disabled.
        self.transformer = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
        self.transformer.norm = nn.Identity()

        # Normalises backbone features before they reach the head.
        self.batch_norm1 = nn.BatchNorm1d(embed_dim)

        head_layers = [
            nn.Linear(embed_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
            nn.Dropout(p=dropout_p),
            nn.Linear(hidden_size, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise the head: Kaiming-normal linears with zero bias, identity batch-norms."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.BatchNorm1d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
                continue
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """
        Run the backbone, the feature batch-norm and the head.

        Returns:
            tuple: ``(logits, features)`` where ``features`` is the
            batch-normalised backbone output that was fed to the head.
        """
        normalised = self.batch_norm1(self.transformer(x))
        scores = self.classifier(normalised)
        return scores, normalised
# 3. Model Loading Function (Updated for Entropy-Based Classifier) | |
def load_model(model_path, device):
    """
    Restore the trained classifier and its class list from a checkpoint file.

    Args:
        model_path (str): Path to the saved .pth checkpoint.
        device (torch.device): Device the checkpoint is mapped to and the model moved onto.

    Returns:
        model (nn.Module): The classifier with restored weights, in eval mode.
        class_names (List[str]): The ``'class_names'`` entry stored in the checkpoint.

    Raises:
        FileNotFoundError: If nothing exists at ``model_path``.
    """
    path_is_missing = not os.path.exists(model_path)
    if path_is_missing:
        raise FileNotFoundError(f"Model file '{model_path}' does not exist.")

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version/defaults - only point this at trusted checkpoints.
    saved = torch.load(model_path, map_location=device)
    labels = saved['class_names']

    # Rebuild the architecture sized to the stored label list, then restore the weights.
    net = DinoVisionTransformerClassifier(num_classes=len(labels))
    net.load_state_dict(saved['model_state_dict'])
    net.to(device)
    net.eval()  # inference behaviour for dropout / batch-norm

    return net, labels
# 4. Image Preprocessing Function (Updated to accept PIL Image directly) | |
def preprocess_image_pil(pil_image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Run a PIL image through the supplied transformation pipeline.

    Args:
        pil_image (PIL.Image.Image): Image to preprocess.
        transform (transforms.Compose): Callable pipeline applied to the image.

    Returns:
        torch.Tensor: Whatever the pipeline produces for this image.
    """
    processed = transform(pil_image)
    return processed
# ---------------------------- | |
# Gradio App Definitions | |
# ---------------------------- | |
# Choose the compute device once at import time: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Text prompt passed to the object detector by predict_beetle.
PROMPT = "bug"
# Define a custom transform for Gaussian blur (Unused in current context) | |
def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
    """
    Randomly blur the items of a 4-D batch, in place.

    Each ``x[i]`` is, with probability ``p``, replaced by the output of a
    ``GaussianBlur`` using a random odd kernel size and a random sigma.
    Inputs whose ``ndim`` is not 4 are returned untouched.

    Args:
        x: Batch to blur; assumed to be laid out as (batch, channels, H, W) - TODO confirm.
        p (float): Per-item probability of applying the blur.
        kernel_size_min (int): Lower bound for the kernel size (rounded up to an odd value).
        kernel_size_max (int): Inclusive upper bound for the kernel size.
        sigma_min (float): Lower bound for the blur sigma.
        sigma_max (float): Upper bound for the blur sigma.

    Returns:
        The same object ``x`` (modified in place wherever a blur was applied).

    Raises:
        ValueError: If no odd kernel size exists in the requested range.
    """
    if x.ndim != 4:
        return x

    # GaussianBlur only accepts odd kernel sizes. Stepping by 2 from an even
    # minimum would produce even sizes, so start from the first odd value
    # >= kernel_size_min (a no-op for the default of 3).
    first_odd = kernel_size_min | 1

    for i in range(x.shape[0]):
        if random.random() < p:
            kernel_size = random.randrange(first_odd, kernel_size_max + 1, 2)
            sigma = random.uniform(sigma_min, sigma_max)
            x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
    return x
# Custom Label Function (Unused in current context) | |
def custom_label_func(fpath):
    """
    Derive a label from a file path using one of its ancestor directories.

    Args:
        fpath: Path-like object exposing ``parents`` (e.g. ``pathlib.Path``).

    Returns:
        str: ``fpath.parents[2].name`` - the directory two levels above the
        folder that contains the file.
    """
    label_dir = fpath.parents[2]
    return label_dir.name
# Image loading function for GroundingDINO | |
def load_image(image_source):
    """
    Convert a PIL image into the transformed tensor handed to GroundingDINO.

    The image is converted to RGB, resized (target size 800, max size 1333),
    turned into a tensor and normalised with the mean/std constants below.

    Args:
        image_source: PIL image (anything exposing ``convert``).

    Returns:
        The transformed image produced by the GroundingDINO transform pipeline.
    """
    rgb_image = image_source.convert("RGB")

    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)

    # These transforms take (image, target) and return a pair; there is no target here.
    tensor_image, _unused_target = pipeline(rgb_image, None)
    return tensor_image
# GroundingDINO object detector, loaded once at import time from relative paths.
od_model = load_groundingdino_model(
    model_config_path="GroundingDINO_SwinT_OGC.cfg.py",
    model_checkpoint_path="groundingdino_swint_ogc.pth",
    device=DEVICE,
)
print("Object detection model loaded")

# DINOv2 classifier checkpoint; change MODEL_PATH to wherever the .pth file lives.
MODEL_PATH = 'dinov2_classifier_with_vos_unsure.pth'
dinov2_model, class_names = load_model(MODEL_PATH, torch.device(DEVICE))
print(f"DINOv2 Classification model loaded with {len(class_names)} classes.")

# No explicit "Unknown" label is appended here; classify_beetle adds it when needed.

# Swap one stored class name for its replacement, if it is present.
target = "Scolotodes_schwarzi"
try:
    idx = class_names.index(target)
except ValueError:
    print(f"'{target}' not found in class names. No replacement made.")
else:
    class_names[idx] = "Scolytodes_glaber"
    print(f"Replaced '{target}' with 'Scolytodes_glaber' in class names.")

# Preprocessing applied to every crop before it reaches the DINOv2 classifier.
dinov2_transform = transforms.Compose([
    transforms.Resize(224),         # shorter edge -> 224
    PadToSquare(),                  # pad to a square canvas
    transforms.Resize((224, 224)),  # final 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])
# Object Detection Function | |
def detect_objects(og_image, model=od_model, prompt="bug . insect", device="cpu"):
    """
    Find objects matching a text prompt and return one cropped image per detection.

    Args:
        og_image: Input image, either a numpy array or a PIL image.
        model: GroundingDINO model used for prediction.
        prompt (str): Caption describing what to detect.
        device (str): Device string passed to the predictor ("cuda" or "cpu").

    Returns:
        list: Crops of the original image, one per predicted box.
    """
    box_threshold = 0.15
    text_threshold = 0.15

    # Accept numpy input by wrapping it; anything else is assumed to be a PIL image already.
    source_image = Image.fromarray(og_image) if isinstance(og_image, np.ndarray) else og_image

    boxes, _logits, _phrases = grounding_dino_predict(
        model=model,
        image=load_image(image_source=source_image),
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    # Scale the predicted boxes by the image size (PIL reports (width, height)),
    # then convert centre/size boxes to corner coordinates for cropping.
    img_w, img_h = source_image.size
    scaled_boxes = boxes * torch.Tensor([img_w, img_h, img_w, img_h])
    corners = box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    img_lst = [source_image.crop(corner) for corner in corners]
    print(f"Detected {len(img_lst)} objects.")
    return img_lst
# Inference/Class Prediction Function using the Entropy-Based DINOv2 Classifier | |
def classify_beetle(img: Image.Image, threshold=75.0):
    """
    Classify one image with the DINOv2 classifier and report the top classes.

    The softmax probabilities are converted to percentages and the three most
    likely classes are returned. When the best class scores below
    ``threshold``, an extra "Unknown" entry is added whose value is the
    normalised entropy of the distribution, expressed as a percentage like
    every other entry.

    Args:
        img (PIL.Image.Image): The image to classify.
        threshold (float): Top-1 confidence, in percent, below which the
            "Unknown" entry is added.

    Returns:
        dict: Class name -> confidence in percent for the top 3 classes, plus
        "Unknown" -> normalised entropy in percent when applicable.
    """
    # The normalisation step expects three channels; crops may be RGBA or grayscale.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Preprocess the image
    input_tensor = preprocess_image_pil(img, dinov2_transform).unsqueeze(0).to(torch.device(DEVICE))
    print(f"Input tensor shape: {input_tensor.shape}")

    with torch.no_grad():
        outputs, _ = dinov2_model(input_tensor)
    print(f"Model outputs: {outputs}")

    probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]  # p(x) in [0,1]
    print(f"Probabilities (0-1 scale): {probabilities}")

    # Shannon entropy; the epsilon avoids log(0).
    epsilon = 1e-12
    entropy = -np.sum(probabilities * np.log(probabilities + epsilon))

    # The entropy of a uniform distribution over n classes is log(n). Computing
    # it as np.sum over a scalar would give only log(n)/n and inflate the
    # normalised value by a factor of n.
    num_classes = len(probabilities)
    max_entropy = float(np.log(num_classes)) if num_classes > 1 else 0.0
    if max_entropy > 0.0:
        normalized_entropy = float(np.clip(entropy / max_entropy, 0.0, 1.0))
    else:
        # A single class carries no uncertainty (and would divide by zero).
        normalized_entropy = 0.0
    unknown_prob = normalized_entropy
    print(f"Entropy: {entropy}, Normalized Entropy: {normalized_entropy}, Unknown Probability: {unknown_prob}")

    # Convert probabilities to percentage for display
    probabilities_percent = np.around(probabilities * 100, decimals=1)
    print(f"Probabilities (Percentage): {probabilities_percent}")

    # Top 3 classes, most likely first (fewer if the model has fewer classes).
    top_indices = np.argsort(probabilities_percent)[-3:][::-1]
    conf_dict = {
        class_names[i]: float(probabilities_percent[i])
        for i in top_indices
    }

    # Flag low-confidence predictions; scale to percent to match the other entries.
    if probabilities_percent[top_indices[0]] < threshold:
        conf_dict["Unknown"] = float(np.around(unknown_prob * 100, decimals=1))

    print(f"Conf_dict: {conf_dict}")
    return conf_dict
# Main Prediction Function for Gradio | |
def predict_beetle(img):
    """
    Detect every object in an image and classify each resulting crop.

    Args:
        img: Image coming from the Gradio input (numpy array or PIL image), or
            ``None`` when no image was supplied.

    Returns:
        list: ``[crop, conf_dict]`` pairs, one per detected object. Empty when
        ``img`` is ``None`` or nothing was detected.
    """
    # Without this guard a missing image would crash inside the detector.
    if img is None:
        print("No input image provided; nothing to classify.")
        return []

    print("Detecting objects in the image...")
    start_time = time.perf_counter()  # Start timing

    # Detect objects in the image
    image_lst = detect_objects(og_image=img, model=od_model, prompt=PROMPT, device=DEVICE)
    img_cnt = len(image_lst)
    print(f"Detected {img_cnt} objects.")

    output_lst = []
    for position, crop in enumerate(image_lst, start=1):
        print(f"Classifying object {position}/{img_cnt}...")
        conf_dict = classify_beetle(crop)
        output_lst.append([crop, conf_dict])
        print(f"Object {position} classified.")

    processing_time = time.perf_counter() - start_time
    print(f"Total processing duration: {processing_time:.2f} seconds")
    return output_lst
# ---------------------------- | |
# Gradio Interface Setup | |
# ---------------------------- | |
sample_images_dir = "example_images"

# Example image paths offered in the UI, in display order.
example_images = [
    os.path.join(sample_images_dir, file_name)
    for file_name in ("example1.jpg", "example2.jpg", "example3.jpg", "mixed.jpg")
]

# One caption per example image, in the same order.
example_labels = [f"Example Beetles {number}" for number in range(1, 5)]
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Intelligent Bark Beetle Identifier (IBBI)</center></h1>")

    # NOTE(review): the container nesting below follows the conventional
    # reading of this layout - confirm against the deployed UI.
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            inputs = gr.Image(label="Input Image")

        # Clickable examples, each shown with its caption.
        gr.Examples(
            examples=example_images,
            inputs=inputs,
            label="Select an example below if you have no images to upload.",
            examples_per_page=4,
            example_labels=example_labels,
        )

        btn = gr.Button("Classify", variant="primary")

    # Gallery layout and height are fixed here in the constructor.
    gallery = gr.Gallery(
        label="Classified Objects",
        show_label=True,
        elem_id="gallery",
        columns=4,
        height="auto",
    )

    def format_gallery(results):
        """Turn (image, confidence-dict) pairs into (image, caption) tuples for the gallery."""
        return [
            (img, ", ".join(f"{k}: {v:.1f}%" for k, v in conf.items()))
            for img, conf in results
        ]

    # Clicking the button runs detection + classification and fills the gallery.
    btn.click(
        fn=lambda img: format_gallery(predict_beetle(img)),
        inputs=inputs,
        outputs=gallery,
    )

# Launch the Gradio app
demo.launch(share=True, inline=True, debug=True, show_error=True)