# Source: Hugging Face Space app.py by venkyvicky — commit aae347b ("Update app.py", verified).
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
from ResNet_for_CC import CC_model # Ensure correct model import
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model weights from a local checkpoint.
model_path = "CC_net.pt"  # Update path if necessary
# num_classes1=14 must agree with the number of entries in class_labels.
model = CC_model(num_classes1=14)
# The checkpoint is only consumed as a state_dict, so restrict unpickling to
# tensors and plain containers (weights_only=True) instead of allowing arbitrary
# pickled objects to execute. map_location lets GPU-saved weights load on CPU.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.to(device)
# Put the network in evaluation mode (affects dropout / batch-norm layers, if any).
model.eval()
# Clothing1M class names, ordered by model output index:
# classify_image uses the argmax index of the logits to pick an entry here.
class_labels = [
    "T-Shirt",      # 0
    "Shirt",        # 1
    "Knitwear",     # 2
    "Chiffon",      # 3
    "Sweater",      # 4
    "Hoodie",       # 5
    "Windbreaker",  # 6
    "Jacket",       # 7
    "Downcoat",     # 8
    "Suit",         # 9
    "Shawl",        # 10
    "Dress",        # 11
    "Vest",         # 12
    "Underwear",    # 13
]
# Per-channel normalization constants. These are the commonly used ImageNet
# mean/std values; presumably they match the training-time preprocessing of
# CC_net.pt — verify against the training script.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image preprocessing pipeline: resize the short side to 256, take the central
# 224x224 crop, convert to a float tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)
# Inference function with confidence scores
def classify_image(image):
    """Classify a clothing image and report the top class with its softmax score.

    Args:
        image: The value from the ``gr.Image(type="pil")`` input, i.e. a PIL
            image, or ``None`` when the user submits without an image.

    Returns:
        A two-line string with the predicted class label and its softmax
        probability as a percentage rounded to two decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; show a friendly message
        # instead of a traceback from the preprocessing pipeline.
        raise gr.Error("Please upload an image or select an example.")

    # Uploaded PNGs may be RGBA, palette or grayscale, but the Normalize step in
    # `transform` has exactly 3 channel statistics, so force 3-channel RGB.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    probabilities = F.softmax(output, dim=1).cpu().numpy()[0]

    predicted_idx = int(probabilities.argmax())
    predicted_label = class_labels[predicted_idx]
    # Softmax probability of the top class. This is a point score, not a
    # statistical confidence interval.
    confidence_pct = round(float(probabilities[predicted_idx]) * 100, 2)
    return f"Predicted Class: {predicted_label}\nConfidence: {confidence_pct}%"
# Example inputs offered in the Gradio UI. The files img1.png .. img5.png must
# be uploaded to the Hugging Face Space next to this script.
example_images = [f"img{number}.png" for number in range(1, 6)]
# Assemble the Gradio UI: one PIL image in, one text box out, backed by
# classify_image. Example outputs are not cached ahead of time.
interface = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    gr.Textbox(label="Prediction and Confidence"),
    examples=example_images,
    cache_examples=False,
    title="Clothing1M Image Classifier with Confidence Interval",
    description=(
        "Upload an image or select from examples to classify it "
        "and view the confidence percentage."
    ),
)
def main():
    """Start the Gradio server hosting the classifier interface."""
    interface.launch()


# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()