aliicemill commited on
Commit
8da80ad
1 Parent(s): 0944830

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -65
app.py DELETED
@@ -1,65 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
- # Save this file as streamlit.py and run it using the command: streamlit run streamlit.py
4
-
5
- import streamlit as st
6
- import torch
7
- from torch import nn
8
- from PIL import Image
9
- import numpy as np
10
- import torchvision.transforms as transforms
11
-
12
- # Define the model class exactly as it was during training if necessary
13
- class CustomModel(nn.Module):
14
- def __init__(self):
15
- super(CustomModel, self).__init__()
16
- # Define layers here exactly as they were in the training script
17
- # For example:
18
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
19
- self.relu = nn.ReLU()
20
- self.fc = nn.Linear(32 * 224 * 224, 90) # Adjust the dimensions and number of classes if necessary
21
-
22
- def forward(self, x):
23
- x = self.conv1(x)
24
- x = self.relu(x)
25
- x = x.view(x.size(0), -1) # Flatten the tensor
26
- x = self.fc(x)
27
- return x
28
-
29
- # Load the model
30
- model = CustomModel()
31
- model.load_state_dict(torch.load('model-CNN.pth')) # or model-CNN.pt if saved with state_dict
32
- model.eval()
33
-
34
- # Function to process the image
35
- def process_image(img):
36
- transform = transforms.Compose([
37
- transforms.Resize((224, 224)),
38
- transforms.ToTensor(),
39
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
40
- ])
41
- img = transform(img)
42
- img = img.unsqueeze(0) # Add batch dimension
43
- return img
44
-
45
- st.title('Animal Classification')
46
- st.write('Please choose an image so that the AI model can predict the type of animal.')
47
- file = st.file_uploader('Pick an image', type=['jpg', 'jpeg', 'png'])
48
-
49
- # Load animal names
50
- with open("name of the animals.txt") as f:
51
- class_names = [x.strip() for x in f.readlines()]
52
-
53
- if file is not None:
54
- img = Image.open(file)
55
- st.image(img, caption='The image: ')
56
- image = process_image(img)
57
-
58
- # Predict with the model
59
- with torch.no_grad():
60
- prediction = model(image)
61
-
62
- predicted_class = torch.argmax(prediction, dim=1).item()
63
- st.write('Probability Distribution')
64
- st.write(prediction.numpy())
65
- st.write("Prediction: ", class_names[predicted_class])