# Commit f9f1c14 (sanpdy): Using huggingface-hosted models
# app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from ultralytics import YOLO
from transformers import ResNetModel
import cv2
from huggingface_hub import hf_hub_download
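
# Pipeline: a Hub-hosted YOLO detector proposes flake bounding boxes, a
# ResNet-18 head classifies each crop into one of the checkpoint's layer
# classes, and the predictions are drawn back onto the uploaded image.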
class FlakeLayerClassifier(nn.Module):
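    """ResNet-18 backbone with two classification heads.

    Given only an image, features from the pooled backbone output go through
    the image-only head (fc_img). When a material id is also supplied, its
    embedding is concatenated with the image features and routed through the
    combined head (fc_comb). This app only exercises the image-only path.
    """
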
    def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
        super().__init__()
        self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
        if freeze_cnn:
            for p in self.cnn.parameters():
                p.requires_grad = False
        img_feat_dim = self.cnn.config.hidden_sizes[-1]
        self.material_embedding = nn.Embedding(num_materials, material_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc_img = nn.Sequential(
            nn.Linear(img_feat_dim, img_feat_dim),
            nn.ReLU(inplace=True),
            self.dropout,
            nn.Linear(img_feat_dim, num_classes),
        )
        combined_dim = img_feat_dim + material_dim
        self.fc_comb = nn.Sequential(
            nn.Linear(combined_dim, combined_dim),
            nn.ReLU(inplace=True),
            self.dropout,
            nn.Linear(combined_dim, num_classes),
        )

    def forward(self, pixel_values, material=None):
        outputs = self.cnn(pixel_values=pixel_values)
        img_feats = outputs.pooler_output.view(outputs.pooler_output.size(0), -1)
        if material is None:
            return self.fc_img(img_feats)
        mat_emb = self.material_embedding(material)
        combined = torch.cat([img_feats, mat_emb], dim=1)
        return self.fc_comb(combined)
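

# Reinhard-style color transfer: match the per-channel LAB mean/std of
# target_img to those of source_img. Both inputs are expected as BGR uint8
# arrays (cv2 convention). Unused in the live request path; see the disabled
# calibration step in detect_and_classify below.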
def calibration(source_img, target_img):
    source_lab = cv2.cvtColor(source_img, cv2.COLOR_BGR2LAB)
    target_lab = cv2.cvtColor(target_img, cv2.COLOR_BGR2LAB)
    for i in range(3):
        src_mean, src_std = cv2.meanStdDev(source_lab[:, :, i])
        tgt_mean, tgt_std = cv2.meanStdDev(target_lab[:, :, i])
        target_lab[:, :, i] = (
            (target_lab[:, :, i] - tgt_mean) * (src_std / tgt_std) + src_mean
        ).clip(0, 255)
    corrected_img = cv2.cvtColor(target_lab, cv2.COLOR_LAB2BGR)
    return corrected_img.astype(np.uint8)
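
# Example usage (hypothetical file names), matching an upload's colors to a
# reference microscope image:
#   ref = cv2.imread("reference.png")   # BGR, as cv2 loads it
#   img = cv2.imread("upload.png")
#   corrected = calibration(ref, img)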
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # first visible GPU; hard-coding "cuda:1" breaks on single-GPU hosts
print(f"Using device: {device}")
# Load YOLO detector
yolo_path = hf_hub_download(
    repo_id="sanpdy/yolo-flake-detector",
    filename="yolo-flake-detector-MSU.pt",
    token=False,
)
yolo = YOLO(yolo_path)
YOLO_CONF = 0.5  # threshold passed per predict call below; assigning yolo.conf is silently ignored by current ultralytics
# Load classifier weights
classifier_path = hf_hub_download(
    repo_id="sanpdy/flake-classifier",
    filename="flake-classifier.pth",
    token=False,
)
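
# The checkpoint is assumed to be a dict holding "model_state_dict" and a
# "class_to_idx" mapping; num_materials is set to the class count, which must
# match the embedding table in the checkpoint for load_state_dict to succeed.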
ckpt = torch.load(classifier_path, map_location=device)
num_classes = len(ckpt["class_to_idx"])
classifier = FlakeLayerClassifier(
    num_materials=num_classes,
    material_dim=64,
    num_classes=num_classes,
    dropout_prob=0.1,
    freeze_cnn=False,
).to(device)
classifier.load_state_dict(ckpt["model_state_dict"])
classifier.eval()
# Image processing transforms
clf_tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    # Standard ImageNet statistics, matching the ResNet-18 backbone's pretraining
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
try:
    FONT = ImageFont.truetype("arial.ttf", 20)
except IOError:
    FONT = ImageFont.load_default()
# Inference + drawing
def detect_and_classify(image: Image.Image):
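    """Detect flakes with YOLO, classify each crop's layer class with the
    ResNet head, and return the image annotated with labeled boxes."""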
    # Optional preprocessing (disabled): color-calibrate the upload against a
    # reference image via calibration() above before running detection.
    img_rgb = np.array(image.convert("RGB"))
    img_bgr = img_rgb[:, :, ::-1].copy()  # RGB -> BGR; copy makes the array contiguous for cv2/ultralytics
    results = yolo(img_bgr, conf=YOLO_CONF, device=str(device))
    boxes = results[0].boxes.xyxy.cpu().numpy()
    scores = results[0].boxes.conf.cpu().numpy()
    draw = ImageDraw.Draw(image)
    for (x1, y1, x2, y2), conf in zip(boxes, scores):
        x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
        crop = image.crop((x1, y1, x2, y2))
        inp = clf_tf(crop).unsqueeze(0).to(device)  # (1, C, H, W)
        with torch.no_grad():
            logits = classifier(pixel_values=inp)
        pred = logits.argmax(1).item()
        prob = F.softmax(logits, dim=1)[0, pred].item()
        label = f"Layer {pred + 1} ({prob:.2f})"
        # Draw the box and label onto the original image
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1, max(0, y1 - 18)), label, fill="red", font=FONT)
    return image
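
# Local smoke test (hypothetical file name):
#   annotated = detect_and_classify(Image.open("sample_flake.png"))
#   annotated.save("annotated.png")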
# Gradio UI
demo = gr.Interface(
    fn=detect_and_classify,
    inputs=gr.Image(type="pil", label="Upload Flake Image"),
    outputs=gr.Image(type="pil", label="Annotated Output"),
    title="Flake Detection + Layer Classification",
    description="Upload an image → YOLO finds flakes → ResNet-18 head classifies their layer.",
)
if __name__ == "__main__":
    demo.launch(share=True)