File size: 2,193 Bytes
e245366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8aa5617
 
e245366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9065d30
 
e245366
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from ResNet_for_CC import CC_model

# Initialize the model on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CC_model()

# Load the pre-trained weights. A state dict saved from an nn.DataParallel
# wrapper prefixes every key with "module."; strip only that *leading* prefix
# so the keys match the bare model. (str.replace would also rewrite any
# "module." that happens to occur later inside a key name.)
# NOTE(review): assumes CC_net.pt holds a plain state dict (a mapping of
# parameter names to tensors), not a wrapper dict — confirm.
model_path = 'CC_net.pt'
checkpoint = torch.load(model_path, map_location=device)
if any(key.startswith('module.') for key in checkpoint.keys()):
    checkpoint = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in checkpoint.items()
    }
model.load_state_dict(checkpoint)
# Inference only: switch off training-mode behaviour, then move to the device.
model.eval()
model.to(device)

# Image preprocessing applied to every input before it reaches the model:
# resize the shorter side to 256 px, take the central 224x224 crop, convert
# to a tensor, then normalise each channel with the mean/std below (the
# usual ImageNet statistics).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable labels, in the order listed in category_names_eng.txt.
# predict() indexes this list with the model's winning class id, so the
# position of each entry matters.
class_names = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

def predict(image):
    """Classify a clothing image and return a human-readable label.

    Args:
        image: The picture delivered by the ``gr.Image`` input. Presumably a
            numpy array of shape (H, W) or (H, W, C), since ``gr.Image`` is
            created without a ``type`` argument — TODO confirm against the
            installed Gradio version. ``None`` when nothing was uploaded.

    Returns:
        ``"Predicted class: <name>"`` for the highest-scoring class, or a
        prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio passes None for an empty input; without this guard the
    # conversion below fails with AttributeError on None.astype.
    if image is None:
        return "Please upload an image first."

    # Let PIL infer the mode from the array, then normalise to RGB, so
    # grayscale and RGBA uploads are converted properly instead of having
    # their buffer forcibly reinterpreted as 3-channel data.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(img).unsqueeze(0).to(device)

    # The model returns a pair; only the second element (class scores) is
    # used for the prediction.
    with torch.no_grad():
        _, output_mean = model(batch)

    # Index of the highest score along dim 1 (the class dimension).
    predicted_index = output_mean.argmax(dim=1).item()
    predicted_class = class_names[predicted_index]

    return f"Predicted class: {predicted_class}"

# Sample inputs offered under the interface (images from Hugging Face).
# Gradio expects one inner list per example, holding the value for each
# input component; there is a single image input, hence one path per list.
# The paths are relative — presumably the files sit next to this script.
examples = [
    [filename]
    for filename in (
        "example_image(1).JPG",
        "example_image(2).jpg",
        "example_image(3).jpg",
        "example_image(4).webp",
        "example_image(5).webp",
        "example_image(6).webp",
    )
]

# Gradio Interface: wire predict() to a single image upload and a single
# text output, showing the sample images as clickable examples.
interface = gr.Interface(
    fn=predict,
    title="Clothing Image Classifier",
    description="This model classifies clothing images using ResNet50. Try out different examples below for a quick demonstration!",
    inputs=gr.Image(label="Upload Clothing Image"),
    outputs=gr.Textbox(label="Prediction"),
    examples=examples,
)

# Start the web server; this runs as soon as the module is executed.
interface.launch()