Spaces:
Sleeping
Sleeping
File size: 2,385 Bytes
7f4df39 a63b185 75b806b 7f4df39 75b806b 3abf34c a63b185 75b806b a63b185 75b806b 35f8363 75b806b 4e96e31 35f8363 75b806b 35f8363 a63b185 75b806b 35f8363 a63b185 75b806b 2603547 75b806b 35f8363 75b806b 35f8363 75b806b 35f8363 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from typing import List, Tuple
# Checkpoint file handed to torch.load() in load_model(). The path is relative,
# so it is resolved against the process's current working directory.
pretrained_vit_path = 'pretrained_vit_model_full.pth'
# Human-readable labels. predict() indexes this list with the arg-max of the
# model output, so the order here must match the model's output indices.
# NOTE(review): 'happy' is lower-case unlike the other labels and is shown
# verbatim in the UI -- confirm this is intentional.
class_names = ['Angry', 'Other', 'Sad', 'happy']
# Model loading.
# `st.experimental_singleton` is the legacy name of Streamlit's resource cache;
# newer releases expose it as `st.cache_resource`. Prefer the current name and
# fall back to the legacy one so the app runs on both old and new Streamlit.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def load_model():
    """Load the serialized model from ``pretrained_vit_path`` and return it.

    The object is deserialized onto the CPU and switched to evaluation mode.
    The call is wrapped in Streamlit's resource cache so that script re-runs
    reuse the already-loaded object instead of reading the file again.

    Returns:
        The object stored in the checkpoint file. It is expected to be a
        complete model: ``.eval()`` is invoked on it here and ``predict``
        calls it directly.
    """
    # NOTE(security): torch.load() unpickles the file, which can execute
    # arbitrary code. Only ever point this at a trusted checkpoint.
    # NOTE(review): recent PyTorch releases default torch.load() to
    # weights_only=True, which rejects fully pickled modules. If the pinned
    # torch version behaves that way, pass weights_only=False here.
    model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'))
    model.eval()
    return model


model = load_model()
# Image pre-processing
def transform_image(image, size=(224, 224)):
    """Turn an image into a single-item batch tensor for the model.

    The image is resized to ``size``, centre-cropped to ``size`` and converted
    to a tensor; a leading batch dimension is then added. No normalization is
    applied.

    Args:
        image: The image to preprocess (the UI passes an RGB PIL image).
        size: Target size forwarded to both the resize and the crop step.

    Returns:
        The transformed image tensor with an extra dimension at position 0.
    """
    # Same steps, same order, applied one after another.
    steps = (
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        # Add other transformations as needed, such as normalization
    )
    result = image
    for step in steps:
        result = step(result)
    # Add batch dimension
    batched = result.unsqueeze(0)
    return batched
# Prediction function
def predict(model, image_tensor, labels=None):
    """Classify the first item of a batch and return its label and confidence.

    The model output is soft-maxed along dimension 1 and the highest-scoring
    class is selected. Label and probability come from the same ``torch.max``
    call, so they always refer to the same class.

    Args:
        model: Callable that maps ``image_tensor`` to a 2-D score tensor
            (softmax is taken over dimension 1).
        image_tensor: Batched input passed straight to ``model``. Only the
            prediction for the first batch item is returned.
        labels: Optional sequence of class labels indexed by class id.
            Defaults to the module-level ``class_names``.

    Returns:
        A ``(label, probability)`` tuple where ``probability`` is a Python
        float in ``[0, 1]``.
    """
    if labels is None:
        labels = class_names

    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top_prob, top_index = torch.max(probabilities, dim=1)

    # .item() converts the 0-d tensors into a plain int / float.
    return labels[top_index[0].item()], top_prob[0].item()
# ---- Streamlit page ----
st.title("Animal Facial Expression Recognition")

# Two equal-width columns: uploader and preview on the left (col1),
# status and prediction on the right (col2).
col1, col2 = st.columns([1, 1])

with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is None:
    # Nothing uploaded yet: show a hint in the results column.
    with col2:
        st.write("Upload an image to see the classification result.")
else:
    # Status line is emitted first, before the image is decoded.
    with col2:
        st.write("Classifying...")

    # Preview of the upload, rendered below the uploader widget.
    # NOTE(review): `use_column_width` is reported as deprecated by newer
    # Streamlit releases -- confirm against the pinned Streamlit version.
    with col1:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Pre-process, classify and report the result beneath the status line.
    with col2:
        image_tensor = transform_image(image)
        predicted_class, probability = predict(model, image_tensor)
        st.write(f'Predicted class: {predicted_class}')
        st.write(f'Probability: {probability:.3f}')
|