import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from typing import List, Tuple

# Assuming your model and class names are set up correctly
pretrained_vit_path = 'pretrained_vit_model_full.pth'
# Index i of this list is used as the label for model output index i.
# NOTE(review): the order presumably mirrors the training dataset's class
# order — verify before editing or re-sorting.
class_names = ['Angry', 'Other', 'Sad', 'happy']

# `st.experimental_singleton` was deprecated in favour of `st.cache_resource`
# and removed from newer Streamlit releases. Prefer the new name and fall back
# to the old one so the app still runs on older installs.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


# Load your model
@_cache_resource
def load_model():
    """Load the fully pickled model from ``pretrained_vit_path`` onto the CPU.

    Streamlit caches the result, so the file is read once per server process
    rather than on every script rerun.

    Returns:
        The deserialized model, switched to evaluation mode.
    """
    # The checkpoint is loaded directly as a model object (``.eval()`` is called
    # on it), i.e. a whole pickled module rather than a state_dict. torch >= 2.6
    # defaults to ``weights_only=True``, which rejects such files, so opt out
    # explicitly.
    # SECURITY: this unpickles arbitrary objects — only load trusted files.
    try:
        model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'),
                           weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'))
    model.eval()
    return model


try:
    model = load_model()
except FileNotFoundError:
    # Fail with a readable message instead of a raw traceback in the browser.
    st.error(f"Model file not found: {pretrained_vit_path}")
    st.stop()


# Function to apply transforms to the image
def transform_image(image, size=(224, 224)):
    """Convert a PIL image into a single-image batch tensor for the model.

    Args:
        image: A PIL image (the UI below converts uploads to RGB first).
        size: Target (height, width) passed to Resize and CenterCrop.

    Returns:
        A float tensor of shape (1, C, H, W); ToTensor scales 8-bit pixel
        values into [0, 1].
    """
    # NOTE(review): Resize with an (h, w) tuple already produces exactly
    # `size`, so CenterCrop is currently a no-op. No normalization is applied;
    # if the model was trained on normalized inputs (common for pretrained ViTs)
    # add transforms.Normalize with the training mean/std — confirm against the
    # training pipeline before changing, since preprocessing must match it.
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        # Add other transformations as needed, such as normalization
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension


# Prediction function
def predict(model, image_tensor) -> Tuple[str, float]:
    """Classify the first image in a batch.

    Args:
        model: Callable returning class scores indexed along dim 1
            (presumably logits of shape (batch, len(class_names)) — confirm).
        image_tensor: Input batch, e.g. from ``transform_image``.

    Returns:
        ``(label, probability)``: the top class name for the first image and
        its softmax probability.
    """
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # A single max over the probabilities yields both the winning index and
        # its confidence, so label and probability always refer to the same
        # class (softmax is monotonic, so this equals argmax of the logits).
        top_prob, top_idx = probabilities.max(dim=1)
    return class_names[top_idx[0].item()], top_prob[0].item()


# Streamlit interface
st.title("Animal Facial Expression Recognition")

# Create two columns for the layout
col1, col2 = st.columns([1, 1])

# First column for the uploader
with col1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_file is None:
    # Show a message when no image is uploaded
    with col2:
        st.write("Upload an image to see the classification result.")
else:
    # Decode the upload. A corrupt or mislabeled file would otherwise crash the
    # whole page with a traceback.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:  # also covers PIL.UnidentifiedImageError
        with col2:
            st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Display the uploaded image in the first column
    with col1:
        # NOTE: newer Streamlit releases deprecate use_column_width in favour of
        # use_container_width; kept for compatibility with older installs.
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Transform the image and make prediction in the second column
    with col2:
        # The spinner is shown only while inference runs, instead of leaving a
        # permanent "Classifying..." line next to the final result.
        with st.spinner("Classifying..."):
            image_tensor = transform_image(image)
            predicted_class, probability = predict(model, image_tensor)
        st.write(f'Predicted class: {predicted_class}')
        st.write(f'Probability: {probability:.3f}')