File size: 2,385 Bytes
7f4df39
a63b185
 
 
75b806b
7f4df39
75b806b
 
 
 
 
3abf34c
a63b185
75b806b
a63b185
 
 
 
 
75b806b
 
35f8363
75b806b
4e96e31
35f8363
75b806b
35f8363
 
a63b185
75b806b
 
 
 
 
 
 
 
 
 
35f8363
a63b185
75b806b
 
2603547
75b806b
 
 
35f8363
75b806b
 
 
 
 
 
 
 
35f8363
75b806b
 
 
 
 
 
 
 
 
 
 
 
 
35f8363
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from typing import List, Tuple

# Checkpoint location. load_model() uses whatever torch.load returns directly as
# the model, so this file is a fully pickled model object rather than a bare
# state_dict.
pretrained_vit_path = 'pretrained_vit_model_full.pth'

# Display labels, indexed by the model's output position along the class axis.
# NOTE(review): the order looks like a plain sorted() of training folder names
# (uppercase sorts before lowercase, hence 'happy' last) — confirm against the
# training run; reordering or renaming entries silently mislabels predictions.
class_names = [
    'Angry',
    'Other',
    'Sad',
    'happy',
]

# st.experimental_singleton is deprecated in favour of st.cache_resource
# (Streamlit 1.18+). Prefer the new decorator and fall back to the old one only
# on installs that predate it.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


# Load your model
@_cache_resource
def load_model():
    """Load the pickled ViT model onto the CPU and put it in eval mode.

    The caching decorator keeps a single instance across Streamlit script
    reruns, so the checkpoint is not unpickled on every user interaction.

    Returns:
        The deserialized model object, already switched to eval mode.
    """
    cpu = torch.device('cpu')
    try:
        # The checkpoint is a whole pickled model, not a state_dict. PyTorch 2.6
        # changed torch.load's default to weights_only=True, which rejects
        # arbitrary pickled classes, so opt out explicitly. Only do this for a
        # trusted local file: unpickling can execute arbitrary code.
        model = torch.load(pretrained_vit_path, map_location=cpu, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no weights_only parameter.
        model = torch.load(pretrained_vit_path, map_location=cpu)
    model.eval()
    return model


model = load_model()

# Preprocessing applied to each uploaded image before inference
def transform_image(image, size=(224, 224)):
    """Turn a PIL image into a single-item batch tensor for the model.

    Args:
        image: PIL image. In this file the caller converts it to RGB first.
        size: Target size handed to both Resize and CenterCrop.

    Returns:
        The ToTensor output with a leading batch dimension of 1 added.
    """
    # With an (h, w) tuple, Resize already produces exactly `size`, so the
    # CenterCrop only has an effect when an int size is passed in.
    steps = [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    # NOTE(review): there is no Normalize step here — confirm this matches the
    # preprocessing the model was trained with.
    pipeline = transforms.Compose(steps)
    batched = pipeline(image).unsqueeze(0)
    return batched

# Prediction function
def predict(model, image_tensor, labels=None) -> Tuple[str, float]:
    """Classify a batch and report the top class for its first item.

    Args:
        model: Callable returning logits with classes along dim 1.
        image_tensor: Input batch, e.g. the output of transform_image.
        labels: Class names indexed by output position. Defaults to the
            module-level ``class_names`` when omitted.

    Returns:
        ``(label, probability)`` for the first item in the batch: the name of
        the highest-scoring class and its softmax probability as a float.
    """
    if labels is None:
        labels = class_names
    with torch.no_grad():
        logits = model(image_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # A single max over the probabilities yields both the winning index and
        # its score, so the reported label and probability always refer to the
        # same class (softmax is monotonic, so this is also the logits argmax).
        top_prob, top_idx = probabilities.max(dim=1)
    return labels[top_idx[0].item()], top_prob[0].item()

# Streamlit interface
st.title("Animal Facial Expression Recognition")

# Two equal-width columns: uploader and preview on the left, results on the right.
col1, col2 = st.columns([1, 1])

# First column for the uploader
with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# The second column holds a single-slot placeholder. Each write into it replaces
# the previous content, so the "Classifying..." status is swapped for the
# result instead of staying on screen above it.
with col2:
    result_slot = st.empty()

if uploaded_file is None:
    # Show a message when no image is uploaded
    result_slot.write("Upload an image to see the classification result.")
else:
    result_slot.write("Classifying...")
    try:
        # The uploader's extension filter does not guarantee valid image data;
        # PIL raises OSError (UnidentifiedImageError is a subclass) for corrupt
        # or non-image files.
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        result_slot.error("Could not read the uploaded file as an image. Please try another file.")
    else:
        # Display the uploaded image in the first column
        with col1:
            st.image(image, caption='Uploaded Image', use_column_width=True)

        # Transform the image, run the model, and replace the status text
        # with the prediction.
        image_tensor = transform_image(image)
        predicted_class, probability = predict(model, image_tensor)
        with result_slot.container():
            st.write(f'Predicted class: {predicted_class}')
            st.write(f'Probability: {probability:.3f}')