|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
from ResNet_for_CC import CC_model |
|
|
|
|
|
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the classifier architecture; trained weights are loaded below.
model = CC_model()

model_path = 'CC_net.pt'

# map_location remaps stored tensors onto `device`, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device)

# Checkpoints saved from a DataParallel-style wrapper typically prefix every
# key with 'module.'. Strip only that *leading* prefix: a plain
# str.replace('module.', '') would also mangle keys that merely contain the
# substring, e.g. 'submodule.weight' -> 'subweight', breaking load_state_dict.
_DP_PREFIX = 'module.'
if any(key.startswith(_DP_PREFIX) for key in checkpoint.keys()):
    checkpoint = {
        (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
        for k, v in checkpoint.items()
    }

model.load_state_dict(checkpoint)

# Switch to inference mode (affects layers such as dropout / batch-norm, if
# the architecture has any) and move the weights onto the chosen device.
model.eval()
model.to(device)
|
|
|
|
|
|
|
|
# Per-channel normalization statistics: these are the widely used ImageNet
# mean / std values, applied after ToTensor() scales pixels to [0, 1].
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every PIL image before it reaches the model:
# shorter side resized to 256, centre 224x224 crop, tensor conversion, then
# channel-wise normalization.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
|
|
|
# Human-readable labels for the classifier's outputs. predict() indexes this
# list with the position of the highest score, so the order here must match
# the order of the model's output units.
class_names = [
    'T-Shirt',      # 0
    'Shirt',        # 1
    'Knitwear',     # 2
    'Chiffon',      # 3
    'Sweater',      # 4
    'Hoodie',       # 5
    'Windbreaker',  # 6
    'Jacket',       # 7
    'Downcoat',     # 8
    'Suit',         # 9
    'Shawl',        # 10
    'Dress',        # 11
    'Vest',         # 12
    'Underwear',    # 13
]
|
|
|
|
|
def predict(image):
    """Classify a clothing image and return a human-readable label.

    Args:
        image: The pixel array supplied by the Gradio image input (``.astype``
            is called on it), or ``None`` if nothing was uploaded.

    Returns:
        ``"Predicted class: <name>"`` where ``<name>`` comes from
        ``class_names``, or a prompt to upload an image when ``image`` is
        ``None``.
    """
    if image is None:
        # Gradio may pass None when the form is submitted without an upload;
        # calling .astype on it would raise AttributeError.
        return "Please upload an image."

    # Let Pillow infer the mode from the array shape and then convert to RGB,
    # instead of Image.fromarray(..., 'RGB'): the mode argument is deprecated
    # in recent Pillow, and convert() also handles grayscale (H, W) or RGBA
    # (H, W, 4) arrays that would otherwise be misread as 3-channel data.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Add a batch dimension of 1 and move onto the model's device.
    batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns a pair; only the second element (the per-class
        # scores) is needed here.
        _, output_mean = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = output_mean.argmax(dim=1).item()
    return f"Predicted class: {class_names[predicted_index]}"
|
|
|
|
|
|
|
|
# Sample images offered beneath the interface; each inner list supplies the
# value for the single image input.
examples = [
    ["example_image(1).JPG"],
    ["example_image(2).jpg"],
    ["example_image(3).jpg"],
    ["example_image(4).webp"],
    ["example_image(5).webp"],
    ["example_image(6).webp"],
]

interface = gr.Interface(
    fn=predict,
    # predict() calls .astype() on its input, so request a NumPy array
    # explicitly rather than relying on gr.Image's default type.
    inputs=gr.Image(label="Upload Clothing Image", type="numpy"),
    outputs=gr.Textbox(label="Prediction"),
    title="Clothing Image Classifier",
    description="This model classifies clothing images using ResNet50. Try out different examples below for a quick demonstration!",
    examples=examples,
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `interface` or `predict`) does not
# block on launching the app.
if __name__ == "__main__":
    interface.launch()