File size: 7,009 Bytes
fc24292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d173bb
fc24292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear
import logging
import os

# Configure the root logger at import time so the INFO-level progress
# messages emitted below (image loading, content types) are printed.
logging.basicConfig(level="INFO")

# Define the model
class KANVGG16(nn.Module):
    """VGG16-style convolutional backbone with a KANLinear classifier head.

    The order of the modules below must not change: ``load_model`` restores a
    checkpoint by ``state_dict`` key, and those keys encode each module's index
    inside ``features`` / ``classifier``.

    Args:
        num_classes: Number of output logits; the default of 1 is a single
            logit for binary (cat vs. dog) classification.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Output channels of each 3x3 conv; "M" closes a stage with a 2x2
        # max-pool followed by batch normalization over that stage's channels.
        stage_spec = (
            64, 64, "M",
            128, 128, "M",
            256, 256, 256, "M",
            512, 512, 512, "M",
            512, 512, 512, "M",
        )
        layers = []
        channels = 3  # RGB input
        for item in stage_spec:
            if item == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                layers.append(nn.BatchNorm2d(channels))
            else:
                layers.append(nn.Conv2d(channels, item, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                channels = item
        self.features = nn.Sequential(*layers)

        # Five 2x2 poolings take a 224x224 input down to 7x7 with 512 channels.
        flattened = 512 * 7 * 7
        self.classifier = nn.Sequential(
            KANLinear(flattened, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(2048, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for the image batch ``x``.

        Assumes ``x`` is (batch, 3, 224, 224); other spatial sizes will not
        match the 512*7*7 input of the first KANLinear layer.
        """
        feature_maps = self.features(x)
        return self.classifier(torch.flatten(feature_maps, start_dim=1))

def load_model(weights_path, device):
    """Construct a ``KANVGG16`` on ``device`` and restore trained weights into it.

    Args:
        weights_path: Path to a checkpoint file containing a ``state_dict``.
        device: ``torch.device`` that the model and loaded tensors are placed on.

    Returns:
        The restored model, switched to evaluation mode.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict by default) when the
            checkpoint's keys or tensor shapes do not match the architecture.
    """
    net = KANVGG16().to(device)
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)

    # Checkpoints saved from a wrapped model (nn.DataParallel / DDP) carry a
    # "module." prefix on every key; drop it so the keys match the bare model.
    prefix = 'module.'
    cleaned = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint.items()
    }

    net.load_state_dict(cleaned)
    net.eval()
    return net

class CustomImageLoadingError(Exception):
    """Raised when an image cannot be fetched, validated, or decoded.

    The message is meant to be shown directly to the user.
    """

def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB ``PIL.Image``.

    The URL *path* must end in one of the accepted extensions (jpg, jpeg,
    png, webp), and the server must answer with an image Content-Type.

    Args:
        url: HTTP(S) address of the image. The query string and fragment are
            ignored when checking the extension.
        timeout: Seconds to wait for the server (forwarded to
            ``requests.get``) so an unresponsive host cannot hang the app.

    Returns:
        The decoded image converted to RGB.

    Raises:
        CustomImageLoadingError: for every failure - bad extension, non-image
            content type, network/HTTP error, or undecodable image data.
    """
    from urllib.parse import urlparse

    try:
        logging.info(f"Loading image from URL: {url}")

        # Check only the extension of the URL path, so that e.g.
        # ".../cat.jpg?width=300" is accepted.
        valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")

        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful

        # A missing header must not surface as a bare KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")

        # Check if the content type is an image
        if 'image' not in content_type.lower():
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")

        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError as e:
        # Validation errors raised above are already user-facing; propagate
        # them unchanged instead of re-wrapping them as "unexpected".
        logging.error(f"Image validation failed: {e}")
        raise
    except requests.HTTPError as e:
        # HTTPError subclasses RequestException, so it must be handled first.
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e

def preprocess_image(image):
    """Convert a PIL image into a single-item batch tensor for the model.

    The image is resized to 224x224 and passed through ``ToTensor``, which
    yields a float CxHxW tensor scaled to [0, 1]. No mean/std normalization
    is applied.

    Returns:
        Tensor of shape (1, C, 224, 224); C is 3 for RGB input.
    """
    pipeline = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    chw = pipeline(image)
    # Prepend the batch dimension.
    return chw[None]

# Streamlit app
#
# Streamlit re-executes this whole script on every widget interaction, so the
# model load (reading the weights file from disk) is cached across reruns.
WEIGHTS_PATH = 'weights/best_model_VGG16_KAN_97.pth'


@st.cache_resource
def _get_cached_model(weights_path, device_type):
    """Load the classifier once per (weights file, device type) and reuse it."""
    return load_model(weights_path, torch.device(device_type))


st.title("Cat and Dog Classification with VGG16-KAN")

st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_cached_model(WEIGHTS_PATH, device.type)

img = None

if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except (UnidentifiedImageError, OSError) as e:
        # A corrupt or non-image upload should produce a message, not a crash.
        logging.error(f"Could not read uploaded image: {e}")
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")

# Define your information for the footer
name = "Wayan Dadang"

st.sidebar.write("Follow me on:")
# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
    [LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
    [GitHub](https://github.com/Wayan123)
    [Resume](https://wayan123.github.io/)
    © {name} - {2024}
    """, unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)

        with torch.no_grad():
            output = model(img_tensor)
            # Single logit -> probability of label 1 (Dog, per the threshold below).
            prob = torch.sigmoid(output).item()

        st.write(f"Prediction: {prob:.4f}")

        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")