File size: 2,663 Bytes
2a0bba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import json

# Page heading shown at the top of the app.
st.title("Image Classification with ResNet18")

# Everything created inside this context is rendered in the sidebar.
with st.sidebar:
    st.header("Model Information")
    st.write("This model is ResNet18 trained to classify images as hotdog or not hotdog.")

# Load the model from Hugging Face
@st.cache_resource
def load_model():
    """Download the hotdog classifier from the Hugging Face Hub and build it.

    Fetches ``resnet_state_dict.pth`` and ``config.json`` from the
    ``asidfactory/hotdognothotdog`` repository, builds a ResNet18 whose final
    fully connected layer has ``config["num_classes"]`` outputs, loads the
    downloaded weights onto the CPU and switches the model to eval mode.

    Returns:
        tuple: ``(model, classes, normalize_mean, normalize_std)``, where the
        last three items are the values stored under the ``"classes"``,
        ``"normalize_mean"`` and ``"normalize_std"`` keys of the config file.
    """
    # Imported locally so the block is self-contained. torchvision is already
    # a dependency of this file (it provides the transforms), and building the
    # backbone from the installed package avoids fetching the pytorch/vision
    # repository from GitHub at runtime via torch.hub.
    from torchvision import models as tv_models

    repo_id = "asidfactory/hotdognothotdog"
    model_path = hf_hub_download(repo_id=repo_id, filename="resnet_state_dict.pth")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    class ResNet18(torch.nn.Module):
        """Thin wrapper around torchvision's ResNet18 with a resized head.

        The backbone must stay under the ``resnet18`` attribute: the saved
        state dict is loaded into this wrapper, so its keys have to match
        this attribute name.
        """

        def __init__(self, num_classes):
            super().__init__()
            # Randomly initialised backbone; the trained weights are loaded
            # from the downloaded state dict below.
            self.resnet18 = tv_models.resnet18()
            self.resnet18.fc = torch.nn.Linear(self.resnet18.fc.in_features, num_classes)

        def forward(self, x):
            return self.resnet18(x)

    model = ResNet18(num_classes=config["num_classes"])

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code if the hub repository is ever compromised. If the minimum
    # supported PyTorch version allows it, pass weights_only=True here.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # Inference only: use eval-mode behaviour for dropout/batch norm.
    return model, config["classes"], config["normalize_mean"], config["normalize_std"]

# Unpack the model, class labels and normalization statistics returned by
# load_model() into module-level names used by the rest of the script.
model, classes, normalize_mean, normalize_std = load_model()

# Image preprocessing function
def preprocess_image(image, input_size, normalize_mean, normalize_std):
    """Turn a PIL image into a normalized, batched tensor for the model.

    Args:
        image: PIL image to preprocess. Images that are not already in RGB
            mode are converted to RGB first.
        input_size: Sequence laid out as ``(channels, height, width)``; only
            the height and width entries are used, as the resize target.
        normalize_mean: Per-channel means passed to ``transforms.Normalize``.
        normalize_std: Per-channel standard deviations passed to
            ``transforms.Normalize``.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1.
    """
    # Uploaded files (PNG in particular) may be RGBA, palette or grayscale.
    # Without this conversion ToTensor yields 4 or 1 channels, which breaks
    # both the per-channel normalization and the 3-channel ResNet input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((input_size[1], input_size[2])),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension

# File uploader for image input
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    # Echo the upload back to the user before classifying it.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Build the batched, normalized tensor the network consumes.
    input_tensor = preprocess_image(
        image,
        input_size=[3, 256, 256],
        normalize_mean=normalize_mean,
        normalize_std=normalize_std,
    )

    # Forward pass without gradient tracking; this is inference only.
    with torch.no_grad():
        outputs = model(input_tensor)

    # Index of the highest-scoring class along the class dimension.
    _, predicted = outputs.max(dim=1)
    predicted_class = classes[predicted.item()]

    # Report the winning label.
    st.write(f"Predicted class: **{predicted_class}**")