# Author: JohnJoelMota
# Commit message: updated output format
# Commit: 8aa5617 (verified)
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from ResNet_for_CC import CC_model
# Initialize the model and load its pre-trained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CC_model()

# The checkpoint may have been saved from a model wrapped in
# torch.nn.DataParallel, in which case parameter names carry a leading
# 'module.' prefix that the bare model does not expect.
model_path = 'CC_net.pt'
_DP_PREFIX = 'module.'
checkpoint = torch.load(model_path, map_location=device)
if any(key.startswith(_DP_PREFIX) for key in checkpoint.keys()):
    # Strip the prefix only at the start of each key; a plain str.replace
    # would also mangle any 'module.' appearing later in a parameter path
    # (e.g. 'backbone.submodule.weight').
    checkpoint = {
        (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
        for k, v in checkpoint.items()
    }
model.load_state_dict(checkpoint)
# Switch layers such as dropout / batch-norm to evaluation behavior.
model.eval()
model.to(device)
# Image preprocessing: scale the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor in [0, 1], then normalize each
# channel with the common ImageNet mean/std (presumably matching the
# model's training pipeline -- confirm).
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class labels taken from category_names_eng.txt. The prediction code maps
# output index i to class_names[i], so this order must not change.
class_names = [
    'T-Shirt',
    'Shirt',
    'Knitwear',
    'Chiffon',
    'Sweater',
    'Hoodie',
    'Windbreaker',
    'Jacket',
    'Downcoat',
    'Suit',
    'Shawl',
    'Dress',
    'Vest',
    'Underwear',
]
def predict(image):
    """Classify a clothing image and return a human-readable label.

    Args:
        image: Image array delivered by the Gradio image input (assumed to
            be a NumPy array of shape H x W or H x W x C -- TODO confirm for
            non-default input settings), or ``None`` when the user submits
            without an image.

    Returns:
        ``"Predicted class: <name>"``, or a prompt to upload an image when
        ``image`` is ``None``.
    """
    if image is None:
        # Without this guard ``image.astype`` raises AttributeError.
        return "Please upload an image."

    # Let Pillow infer the mode from the array shape, then normalize to RGB,
    # so grayscale (H x W) and RGBA (H x W x 4) inputs are converted rather
    # than having their raw bytes misread as 3-channel RGB data.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(img).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        # The model returns a pair; only the second element (per-class
        # scores, assumed shape (batch, num_classes)) is used here.
        _, output_mean = model(batch)

    # Index of the highest score along the class dimension.
    _, predicted = torch.max(output_mean, 1)
    predicted_class = class_names[predicted.item()]
    return f"Predicted class: {predicted_class}"
# Example image files offered as one-click demos. Each row holds the value
# for the interface's single image input, hence the one-element lists.
examples = [
    [filename]
    for filename in (
        "example_image(1).JPG",
        "example_image(2).jpg",
        "example_image(3).jpg",
        "example_image(4).webp",
        "example_image(5).webp",
        "example_image(6).webp",
    )
]
# Gradio UI: one image upload wired to predict(), the resulting label shown
# in a text box, and the example files listed below the form.
image_input = gr.Image(label="Upload Clothing Image")
prediction_output = gr.Textbox(label="Prediction")
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_output,
    title="Clothing Image Classifier",
    description=(
        "This model classifies clothing images using ResNet50. "
        "Try out different examples below for a quick demonstration!"
    ),
    examples=examples,
)
interface.launch()