mandali8686 commited on
Commit
75b806b
1 Parent(s): 9699cad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -20
app.py CHANGED
@@ -2,42 +2,70 @@ import streamlit as st
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
 
5
 
6
- # Load your model (ensure this is the correct path to your model file)
 
 
 
 
7
  @st.cache(allow_output_mutation=True)
8
  def load_model():
9
- model = torch.load('pretrained_vit_model_full.pth', map_location=torch.device('cpu'))
10
  model.eval()
11
  return model
12
 
13
  model = load_model()
14
 
15
- # Function to apply transforms to the image (update as per your model's requirement)
16
- def transform_image(image):
17
  transform = transforms.Compose([
18
- transforms.Resize((224, 224)), # Resize to the input size that your model expects
 
19
  transforms.ToTensor(),
20
- # Add other transformations as needed
21
  ])
22
  return transform(image).unsqueeze(0) # Add batch dimension
23
 
 
 
 
 
 
 
 
 
 
 
24
  st.title("Animal Facial Expression Recognition")
25
 
 
 
26
 
27
- # File uploader
28
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
29
- if uploaded_file is not None:
30
- image = Image.open(uploaded_file).convert('RGB')
31
- st.image(image, caption='Uploaded Image.', use_column_width=True)
32
- st.write("")
33
- st.write("Classifying...")
34
 
35
- # Transform the image
36
- input_tensor = transform_image(image)
 
 
 
 
 
 
37
 
38
- # Make prediction
39
- with torch.no_grad():
40
- prediction = model(input_tensor)
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Display the prediction (modify as per your output)
43
- st.write('Predicted class:', prediction.argmax().item())
 
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
5
+ from typing import List, Tuple
6
 
7
+ # Assuming your model and class names are set up correctly
8
+ pretrained_vit_path = 'pretrained_vit_model_full.pth'
9
+ class_names = ['Angry', 'Other', 'Sad', 'happy']
10
+
11
+ # Load your model
12
  @st.cache(allow_output_mutation=True)
13
  def load_model():
14
+ model = torch.load(pretrained_vit_path, map_location=torch.device('cpu'))
15
  model.eval()
16
  return model
17
 
18
  model = load_model()
19
 
20
+ # Function to apply transforms to the image
21
+ def transform_image(image, size=(224, 224)):
22
  transform = transforms.Compose([
23
+ transforms.Resize(size),
24
+ transforms.CenterCrop(size), # Ensures image is square
25
  transforms.ToTensor(),
26
+ # Add other transformations as needed, such as normalization
27
  ])
28
  return transform(image).unsqueeze(0) # Add batch dimension
29
 
30
+ # Prediction function
31
+ def predict(model, image_tensor):
32
+ with torch.no_grad():
33
+ outputs = model(image_tensor)
34
+ _, predicted = torch.max(outputs, 1)
35
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
36
+ top_prob, top_catid = torch.topk(probabilities, 1)
37
+ return class_names[predicted[0]], top_prob[0].item()
38
+
39
+ # Streamlit interface
40
  st.title("Animal Facial Expression Recognition")
41
 
42
+ # Create two columns for the layout
43
+ col1, col2 = st.columns([1, 1])
44
 
45
+ # First column for the uploader
46
+ with col1:
47
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
 
 
 
 
48
 
49
+ # Second column for the prediction results
50
+ with col2:
51
+ if uploaded_file is not None:
52
+ # Display "Classifying..." text
53
+ st.write("Classifying...")
54
+ else:
55
+ # Show a message when no image is uploaded
56
+ st.write("Upload an image to see the classification result.")
57
 
58
+ # If an image has been uploaded, display it and run the prediction
59
+ if uploaded_file is not None:
60
+ # Display the uploaded image in the first column
61
+ with col1:
62
+ image = Image.open(uploaded_file).convert('RGB')
63
+ st.image(image, caption='Uploaded Image', use_column_width=True)
64
+
65
+ # Transform the image and make prediction in the second column
66
+ with col2:
67
+ image_tensor = transform_image(image)
68
+ predicted_class, probability = predict(model, image_tensor)
69
+ st.write(f'Predicted class: {predicted_class}')
70
+ st.write(f'Probability: {probability:.3f}')
71