mandali8686 commited on
Commit
a63b185
1 Parent(s): fc3627e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -1
app.py CHANGED
@@ -1,4 +1,46 @@
1
  import streamlit as st
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
 
6
+ # Load your model (ensure this is the correct path to your model file)
7
+ @st.cache(allow_output_mutation=True)
8
+ def load_model():
9
+ model = torch.load('pretrained_vit_model_full.pth', map_location=torch.device('cpu'))
10
+ model.eval()
11
+ return model
12
+
13
+ model = load_model()
14
+
15
+ # Function to apply transforms to the image (update as per your model's requirement)
16
+ def transform_image(image):
17
+ transform = transforms.Compose([
18
+ transforms.Resize((224, 224)), # Resize to the input size that your model expects
19
+ transforms.ToTensor(),
20
+ # Add other transformations as needed
21
+ ])
22
+ return transform(image).unsqueeze(0) # Add batch dimension
23
+
24
+ st.title("Animal Facial Expression Recognition")
25
+
26
+ # Slider
27
  x = st.slider('Select a value')
28
+ st.write(x, 'squared is', x * x)
29
+
30
+ # File uploader
31
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
32
+ if uploaded_file is not None:
33
+ image = Image.open(uploaded_file).convert('RGB')
34
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
35
+ st.write("")
36
+ st.write("Classifying...")
37
+
38
+ # Transform the image
39
+ input_tensor = transform_image(image)
40
+
41
+ # Make prediction
42
+ with torch.no_grad():
43
+ prediction = model(input_tensor)
44
+
45
+ # Display the prediction (modify as per your output)
46
+ st.write('Predicted class:', prediction.argmax().item())