import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear


class CNNKAN(nn.Module):
    """Convolutional feature extractor followed by two KANLinear layers.

    Four stages of conv -> batch-norm -> SELU -> 2x2 max-pool map a 3-channel
    200x200 image (the size produced by ``preprocess_image``) down to a
    256x12x12 feature map (spatial size 200 -> 100 -> 50 -> 25 -> 12).
    That map is flattened and passed through ``KANLinear(256*12*12, 512)`` and
    ``KANLinear(512, 1)``, so ``forward`` returns one raw logit per image;
    the app applies ``torch.sigmoid`` to it.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.5)
        # Input width is tied to a 200x200 input: four 2x2 pools leave 12x12.
        self.kan1 = KANLinear(256 * 12 * 12, 512)
        self.kan2 = KANLinear(512, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of logits for an image batch ``x``.

        Assumes ``x`` is (batch, 3, 200, 200); other spatial sizes will not
        match ``kan1``'s input width.
        """
        x = self.pool1(F.selu(self.bn1(self.conv1(x))))
        x = self.pool2(F.selu(self.bn2(self.conv2(x))))
        x = self.pool3(F.selu(self.bn3(self.conv3(x))))
        x = self.pool4(F.selu(self.bn4(self.conv4(x))))
        x = x.view(x.size(0), -1)
        # Dropout is a no-op once load_model() puts the module in eval mode.
        x = self.dropout(x)
        x = self.kan1(x)
        x = self.dropout(x)
        x = self.kan2(x)
        return x


def load_model(weights_path, device):
    """Build a CNNKAN on ``device``, load ``weights_path`` into it, set eval mode.

    Keys beginning with ``'module.'`` have that prefix stripped so the
    checkpoint matches the bare model's parameter names. That prefix is
    typically left by training under an ``nn.DataParallel`` wrapper.

    Args:
        weights_path: Path to a ``torch.save``'d state dict.
        device: ``torch.device`` to place the model and tensors on.

    Returns:
        The loaded ``CNNKAN`` in eval mode.
    """
    model = CNNKAN().to(device)
    state_dict = torch.load(weights_path, map_location=device)
    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned)
    model.eval()
    return model


def load_image_from_url(url, timeout=10):
    """Download ``url`` and decode the response body as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server (passed to ``requests.get``)
            so a stalled host cannot hang the app indefinitely.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status (via ``raise_for_status``).
        PIL.UnidentifiedImageError: if the body is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


def preprocess_image(image):
    """Resize a PIL image to 200x200 and return a batched tensor.

    ``ToTensor`` yields (C, 200, 200) floats in [0, 1]. ``unsqueeze(0)`` adds
    the batch axis, giving (1, C, 200, 200); C is 3 for the RGB images this
    app passes in.
    NOTE(review): no mean/std normalization is applied; presumably this
    matches how the model was trained — confirm.
    """
    transform = transforms.Compose([
        transforms.Resize((200, 200)),
        transforms.ToTensor()
    ])
    return transform(image).unsqueeze(0)


@st.cache_resource
def _get_model(weights_path, device_type):
    """Load the model once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the weights would be re-read from disk and the model rebuilt
    each time. ``device_type`` is a plain string so it hashes cleanly as a
    cache key.
    """
    return load_model(weights_path, torch.device(device_type))


# Streamlit app
st.title("Cat and Dog Classification with CNN-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_model('weights/best_model_weights_KAN.pth', device.type)

# An uploaded file takes precedence over a URL. Load failures are reported
# in the sidebar rather than crashing the page.
img = None
if uploaded_file is not None:
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except Exception as e:
        st.sidebar.error(f"Error loading uploaded image: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except Exception as e:
        st.sidebar.error(f"Error loading image from URL: {e}")

st.sidebar.write("-----")

# Define your information for the footer
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")

# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)

    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)
        with torch.no_grad():
            output = model(img_tensor)
        # Single logit -> probability; values >= 0.5 are labelled Dog.
        prob = torch.sigmoid(output).item()
        st.write(f"Prediction: {prob:.4f}")
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog")