aliicemill commited on
Commit
e935a59
1 Parent(s): 8da80ad

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ # Save this file as streamlit.py and run it using the command: streamlit run streamlit.py
4
+
5
+ import streamlit as st
6
+ import torch
7
+ from torch import nn
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torchvision.transforms as transforms
11
+
12
+ # Custom model definition
13
+ class CustomModel(nn.Module):
14
+ def __init__(self):
15
+ super(CustomModel, self).__init__()
16
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
17
+ self.relu = nn.ReLU()
18
+ self.fc = nn.Linear(32 * 224 * 224, 90) # Adjust the dimensions and number of classes if necessary
19
+
20
+ def forward(self, x):
21
+ x = self.conv1(x)
22
+ x = self.relu(x)
23
+ x = x.view(x.size(0), -1) # Flatten the tensor
24
+ x = self.fc(x)
25
+ return x
26
+
27
+ # Function to process the image
28
+ def process_image(img):
29
+ transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33
+ ])
34
+ img = transform(img)
35
+ img = img.unsqueeze(0) # Add batch dimension
36
+ return img
37
+
38
+ st.title('Animal Classification')
39
+ st.write('Please choose an image so that the AI model can predict the type of animal.')
40
+ file = st.file_uploader('Pick an image', type=['jpg', 'jpeg', 'png'])
41
+
42
+ # Load animal names
43
+ with open("name of the animals.txt") as f:
44
+ class_names = [x.strip() for x in f.readlines()]
45
+
46
+ if file is not None:
47
+ img = Image.open(file)
48
+ st.image(img, caption='The image: ')
49
+ image = process_image(img)
50
+
51
+ # Load the model
52
+ model = CustomModel()
53
+ model.load_state_dict(torch.load('model-CNN.pth', map_location=torch.device('cpu'))) # Load state dict
54
+ model.eval()
55
+
56
+ # Predict with the model
57
+ with torch.no_grad():
58
+ prediction = model(image)
59
+
60
+ predicted_class = torch.argmax(prediction, dim=1).item()
61
+ st.write('Probability Distribution')
62
+ st.write(prediction.numpy())
63
+ st.write("Prediction: ", class_names[predicted_class])