new-pet-pet / app.py
mandali8686's picture
Update app.py
4e96e31
raw
history blame contribute delete
No virus
2.39 kB
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from typing import List, Tuple
# Checkpoint file holding the complete pickled model. It is a relative path,
# so it is resolved against the process's current working directory.
pretrained_vit_path = "pretrained_vit_model_full.pth"

# Model output index -> human-readable label.
# NOTE(review): the order presumably has to match the label indices used at
# training time. The list is ASCII-sorted (uppercase before lowercase), which
# is what torchvision's ImageFolder would produce from folder names — confirm
# before reordering or re-casing any entry.
class_names = [
    "Angry",
    "Other",
    "Sad",
    "happy",
]
# Load your model
@st.experimental_singleton
def load_model():
model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'))
model.eval()
return model
model = load_model()
# Function to apply transforms to the image
def transform_image(image, size=(224, 224), *, mean=None, std=None):
    """Convert a PIL image into a batched float tensor for the model.

    Args:
        image: PIL image. Non-RGB modes (e.g. RGBA PNGs, greyscale, palette) are
            converted to RGB first, so the tensor always has 3 channels.
        size: Passed to both ``Resize`` and ``CenterCrop``.
            - With an ``(h, w)`` tuple, the image is resized to exactly that size
              (aspect ratio is not kept) and the crop is a no-op.
            - With an int, the shorter side is scaled to ``size`` and the centre
              ``size x size`` square is kept.
        mean, std: Optional per-channel normalisation statistics. They must be
            given together. When both are None (the default), pixel values stay
            in [0, 1] exactly as before.
            NOTE(review): confirm which statistics, if any, the model was trained with.

    Returns:
        Tensor of shape ``(1, 3, H, W)`` (batch dimension added).

    Raises:
        ValueError: if only one of ``mean``/``std`` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")
    # ToTensor keeps the source channel count, so an RGBA or greyscale upload
    # would otherwise yield a 4- or 1-channel tensor.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    steps = [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    if mean is not None:
        steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)(image).unsqueeze(0)  # Add batch dimension
# Prediction function
def predict(model, image_tensor, labels=None):
    """Classify a batched image tensor and return the top prediction for its first item.

    Args:
        model: Callable mapping ``image_tensor`` to logits shaped ``(batch, classes)``.
        image_tensor: Batched model input, e.g. as produced by ``transform_image``.
        labels: Sequence mapping class index to name. Defaults to the
            module-level ``class_names``.

    Returns:
        ``(label, probability)`` for the first batch element. ``probability`` is
        the softmax probability of the predicted class, as a Python float.
    """
    names = class_names if labels is None else labels
    with torch.no_grad():
        logits = model(image_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # Softmax is monotonic, so its argmax equals the argmax of the logits.
        # A single max() yields both the winning probability and its index.
        top_prob, top_idx = probabilities.max(dim=1)
    return names[top_idx[0].item()], top_prob[0].item()
# --- Streamlit interface ---------------------------------------------------
st.title("Animal Facial Expression Recognition")

# Two equal-width columns: uploader and image preview on the left, results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is None:
    # Nothing uploaded yet: show a hint where the result will appear.
    with col2:
        st.write("Upload an image to see the classification result.")
else:
    try:
        # Image.open is lazy; convert() forces decoding, so corrupt data
        # surfaces here rather than later inside the model pipeline.
        image = Image.open(uploaded_file).convert('RGB')
    except (OSError, Image.DecompressionBombError):
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode, e.g. a renamed non-image file. DecompressionBombError
        # is raised for suspiciously huge images.
        with col2:
            st.error("Could not read the uploaded file as an image.")
    else:
        with col1:
            # NOTE(review): newer Streamlit releases deprecate `use_column_width`
            # in favour of `use_container_width`; switch once the pinned version
            # supports it.
            st.image(image, caption='Uploaded Image', use_column_width=True)
        with col2:
            # The spinner is shown only while inference runs, unlike a plain
            # st.write, which would stay on screen next to the result.
            with st.spinner("Classifying..."):
                image_tensor = transform_image(image)
                predicted_class, probability = predict(model, image_tensor)
            st.write(f'Predicted class: {predicted_class}')
            st.write(f'Probability: {probability:.3f}')