File size: 3,135 Bytes
5303063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222cfe5
 
 
5303063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222cfe5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import torch
import pickle
import joblib
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from transformers import AutoModelForImageClassification
from torch import nn
from torchvision import transforms
from huggingface_hub import hf_hub_download

# File names of the artifacts stored in the Hugging Face model repository.
MODEL_PATH = "DeiT_Model_Parameter.pth"  # saved state_dict of the classifier
ENCODER_PATH = "label_encoder.pkl"  # encoder mapping class indices to names

# Prefer the default CUDA device when one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_label_encoder(repo_id="bobs24/DeiT-Classification-Apparel"):
    """Download the label encoder from the Hugging Face Hub and deserialize it.

    Args:
        repo_id: Hub repository that holds the ``ENCODER_PATH`` file.
            Defaults to ``"bobs24/DeiT-Classification-Apparel"``, so existing
            zero-argument calls behave exactly as before.

    Returns:
        The deserialized label encoder object. In this app it is used through
        its ``classes_`` attribute and its ``inverse_transform`` method.
    """
    label_encoder_path = hf_hub_download(repo_id=repo_id, filename=ENCODER_PATH)
    # SECURITY: joblib.load unpickles the file, and unpickling can execute
    # arbitrary code. Only point ``repo_id`` at repositories you trust.
    return joblib.load(label_encoder_path)

class CustomModel(nn.Module):
    """Wrap a Hugging Face DeiT image classifier so ``forward`` returns raw logits.

    Args:
        num_classes: Number of output labels for the classification head.
            ``ignore_mismatched_sizes=True`` lets ``from_pretrained`` load the
            base checkpoint even when its head size differs from this value.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.base_model = AutoModelForImageClassification.from_pretrained(
            "facebook/deit-base-patch16-224",
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its ``logits``."""
        outputs = self.base_model(x)
        return outputs.logits

def load_model():
    """Build the classifier, load its saved weights, and return it with the label encoder.

    The head size is taken from ``len(label_encoder.classes_)``. The saved
    state_dict is loaded strictly, so it must match that head size.

    Returns:
        A ``(model, label_encoder)`` tuple. ``model`` is a ``CustomModel`` that
        has been moved to ``device``, switched to eval mode, and tagged with a
        ``device`` attribute.
    """
    model_path = hf_hub_download(repo_id="bobs24/DeiT-Classification-Apparel", filename=MODEL_PATH)
    label_encoder = load_label_encoder()
    model = CustomModel(num_classes=len(label_encoder.classes_)).to(device)
    # The checkpoint comes from a remote repository. weights_only=True restricts
    # unpickling to tensors and plain containers, which is all a state_dict
    # needs, so a tampered file cannot execute code. Requires torch >= 1.13.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Exposes the device on the model object; kept as-is in case other code reads it.
    model.device = device
    model.eval()  # put layers such as dropout into inference behavior

    return model, label_encoder

# Load the model and label encoder once at import time; predict() reads
# these module-level globals on every request.
model, label_encoder = load_model()

# Inference-time preprocessing. This presumably mirrors the training pipeline;
# verify it against the training code.
preprocess = transforms.Compose([
    transforms.Resize(256),  # scale the SHORTER side to 256 px, keeping aspect ratio (not 256x256)
    transforms.CenterCrop(224),  # take the central 224x224 patch
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # ImageNet per-channel mean/std
])

def predict(image):
    """Classify one image and describe the top prediction with its probability.

    Args:
        image: Array supplied by the ``gr.Image(type="numpy")`` input, or
            ``None`` when no image was provided.

    Returns:
        ``"Please insert photo"`` when ``image`` is ``None``. Otherwise, a string
        with the predicted class name and its softmax confidence (4 decimals).
    """
    # Guard clause: nothing to classify.
    if image is None:
        return "Please insert photo"

    pil_image = Image.fromarray(image).convert("RGB")
    # Add a batch dimension of size 1 and move it to the model's device.
    batch = preprocess(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)  # logits -> per-class probabilities
        top_index = torch.argmax(probs, dim=1).item()
        top_prob = probs[0, top_index].item()
        # Map the integer class index back to its human-readable label.
        class_name = label_encoder.inverse_transform([top_index])[0]

    return f"Predicted class: {class_name}, Confidence: {top_prob:.4f}"

# Gradio UI: a numpy image input wired to `predict`, with a plain-text output.
# live=True re-runs `predict` whenever the input changes (no submit button).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    live=True,
)

# Start the web server.
# NOTE(review): this runs on import too; consider an `if __name__ == "__main__":`
# guard if the module is ever imported rather than executed as a script.
iface.launch()