wayandadang committed on
Commit 12167ec
1 Parent(s): 5c2bda4

update app.py

Files changed (2)
  1. app-old.txt +128 -0
  2. app.py +1 -1
app-old.txt ADDED
@@ -0,0 +1,128 @@
+
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+from PIL import Image
+import streamlit as st
+import numpy as np
+import requests
+from io import BytesIO
+from kan_linear import KANLinear
+
+class CNNKAN(nn.Module):
+    def __init__(self):
+        super(CNNKAN, self).__init__()
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.pool2 = nn.MaxPool2d(2)
+        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.bn3 = nn.BatchNorm2d(128)
+        self.pool3 = nn.MaxPool2d(2)
+        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
+        self.bn4 = nn.BatchNorm2d(256)
+        self.pool4 = nn.MaxPool2d(2)
+        self.dropout = nn.Dropout(0.5)
+        self.kan1 = KANLinear(256 * 12 * 12, 512)
+        self.kan2 = KANLinear(512, 1)
+
+    def forward(self, x):
+        x = F.selu(self.bn1(self.conv1(x)))
+        x = self.pool1(x)
+        x = F.selu(self.bn2(self.conv2(x)))
+        x = self.pool2(x)
+        x = F.selu(self.bn3(self.conv3(x)))
+        x = self.pool3(x)
+        x = F.selu(self.bn4(self.conv4(x)))
+        x = self.pool4(x)
+        x = x.view(x.size(0), -1)
+        x = self.dropout(x)
+        x = self.kan1(x)
+        x = self.dropout(x)
+        x = self.kan2(x)
+        return x
+
+def load_model(weights_path, device):
+    model = CNNKAN().to(device)
+    state_dict = torch.load(weights_path, map_location=device)
+
+    # Remove 'module.' prefix from keys
+    from collections import OrderedDict
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            new_state_dict[k[7:]] = v
+        else:
+            new_state_dict[k] = v
+
+    model.load_state_dict(new_state_dict)
+    model.eval()
+    return model
+
+def load_image_from_url(url):
+    response = requests.get(url)
+    img = Image.open(BytesIO(response.content)).convert('RGB')
+    return img
+
+def preprocess_image(image):
+    transform = transforms.Compose([
+        transforms.Resize((200, 200)),
+        transforms.ToTensor()
+    ])
+    return transform(image).unsqueeze(0)
+
+# Streamlit app
+st.title("Cat and Dog Classification with CNN-KAN")
+
+st.sidebar.title("Upload Images")
+uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
+image_url = st.sidebar.text_input("Or enter image URL...")
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = load_model('weights/best_model_weights_KAN.pth', device)
+
+img = None
+
+if uploaded_file is not None:
+    img = Image.open(uploaded_file).convert('RGB')
+elif image_url:
+    try:
+        img = load_image_from_url(image_url)
+    except Exception as e:
+        st.sidebar.error(f"Error loading image from URL: {e}")
+
+st.sidebar.write("-----")
+
+# Define your information for the footer
+name = "Wayan Dadang"
+
+st.sidebar.write("Follow me on:")
+# Create a footer section with links and copyright information
+st.sidebar.markdown(f"""
+[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
+[GitHub](https://github.com/Wayan123)
+[Resume](https://wayan123.github.io/)
+© {name} - {2024}
+""", unsafe_allow_html=True)
+
+if img is not None:
+    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
+    if st.button('Predict'):
+        img_tensor = preprocess_image(img).to(device)
+
+        with torch.no_grad():
+            output = model(img_tensor)
+            prob = torch.sigmoid(output).item()
+
+        st.write(f"Prediction: {prob:.4f}")
+
+        if prob < 0.5:
+            st.write("This image is classified as a Cat.")
+        else:
+            st.write("This image is classified as a Dog")
+
app.py CHANGED
@@ -107,7 +107,7 @@ def load_image_from_url(url):
 
 def preprocess_image(image):
     transform = transforms.Compose([
-        transforms.Resize((224, 224)),
+        transforms.Resize((200, 200)),
         transforms.ToTensor()
     ])
     return transform(image).unsqueeze(0)