Spaces:
Sleeping
Sleeping
Commit
·
75b806b
1
Parent(s):
9699cad
Update app.py
Browse files
app.py
CHANGED
@@ -2,42 +2,70 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
from torchvision import transforms
|
|
|
5 |
|
6 |
-
#
|
|
|
|
|
|
|
|
|
7 |
@st.cache(allow_output_mutation=True)
|
8 |
def load_model():
|
9 |
-
model = torch.load(
|
10 |
model.eval()
|
11 |
return model
|
12 |
|
13 |
model = load_model()
|
14 |
|
15 |
-
# Function to apply transforms to the image
|
16 |
-
def transform_image(image):
|
17 |
transform = transforms.Compose([
|
18 |
-
transforms.Resize(
|
|
|
19 |
transforms.ToTensor(),
|
20 |
-
# Add other transformations as needed
|
21 |
])
|
22 |
return transform(image).unsqueeze(0) # Add batch dimension
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
st.title("Animal Facial Expression Recognition")
|
25 |
|
|
|
|
|
26 |
|
27 |
-
#
|
28 |
-
|
29 |
-
|
30 |
-
image = Image.open(uploaded_file).convert('RGB')
|
31 |
-
st.image(image, caption='Uploaded Image.', use_column_width=True)
|
32 |
-
st.write("")
|
33 |
-
st.write("Classifying...")
|
34 |
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
# Display the prediction (modify as per your output)
|
43 |
-
st.write('Predicted class:', prediction.argmax().item())
|
|
|
2 |
import torch
|
3 |
from PIL import Image
|
4 |
from torchvision import transforms
|
5 |
+
from typing import List, Tuple
|
6 |
|
7 |
+
# Assuming your model and class names are set up correctly
|
8 |
+
pretrained_vit_path = 'pretrained_vit_model_full.pth'
|
9 |
+
class_names = ['Angry', 'Other', 'Sad', 'happy']
|
10 |
+
|
11 |
+
# Load your model
|
12 |
@st.cache(allow_output_mutation=True)
|
13 |
def load_model():
|
14 |
+
model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'))
|
15 |
model.eval()
|
16 |
return model
|
17 |
|
18 |
model = load_model()
|
19 |
|
20 |
+
# Function to apply transforms to the image
|
21 |
+
def transform_image(image, size=(224, 224)):
|
22 |
transform = transforms.Compose([
|
23 |
+
transforms.Resize(size),
|
24 |
+
transforms.CenterCrop(size), # Ensures image is square
|
25 |
transforms.ToTensor(),
|
26 |
+
# Add other transformations as needed, such as normalization
|
27 |
])
|
28 |
return transform(image).unsqueeze(0) # Add batch dimension
|
29 |
|
30 |
+
# Prediction function
|
31 |
+
def predict(model, image_tensor):
|
32 |
+
with torch.no_grad():
|
33 |
+
outputs = model(image_tensor)
|
34 |
+
_, predicted = torch.max(outputs, 1)
|
35 |
+
probabilities = torch.nn.functional.softmax(outputs, dim=1)
|
36 |
+
top_prob, top_catid = torch.topk(probabilities, 1)
|
37 |
+
return class_names[predicted[0]], top_prob[0].item()
|
38 |
+
|
39 |
+
# Streamlit interface
|
40 |
st.title("Animal Facial Expression Recognition")
|
41 |
|
42 |
+
# Create two columns for the layout
|
43 |
+
col1, col2 = st.columns([1, 1])
|
44 |
|
45 |
+
# First column for the uploader
|
46 |
+
with col1:
|
47 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
# Second column for the prediction results
|
50 |
+
with col2:
|
51 |
+
if uploaded_file is not None:
|
52 |
+
# Display "Classifying..." text
|
53 |
+
st.write("Classifying...")
|
54 |
+
else:
|
55 |
+
# Show a message when no image is uploaded
|
56 |
+
st.write("Upload an image to see the classification result.")
|
57 |
|
58 |
+
# If an image has been uploaded, display it and run the prediction
|
59 |
+
if uploaded_file is not None:
|
60 |
+
# Display the uploaded image in the first column
|
61 |
+
with col1:
|
62 |
+
image = Image.open(uploaded_file).convert('RGB')
|
63 |
+
st.image(image, caption='Uploaded Image', use_column_width=True)
|
64 |
+
|
65 |
+
# Transform the image and make prediction in the second column
|
66 |
+
with col2:
|
67 |
+
image_tensor = transform_image(image)
|
68 |
+
predicted_class, probability = predict(model, image_tensor)
|
69 |
+
st.write(f'Predicted class: {predicted_class}')
|
70 |
+
st.write(f'Probability: {probability:.3f}')
|
71 |
|
|
|
|