# Streamlit app: classify an image (uploaded or fetched from a URL) as a cat
# or a dog using a VGG16-style network whose classifier uses KANLinear layers.

import logging
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, UnidentifiedImageError
from torchvision import transforms, models

from kan_linear import KANLinear

# Setup logging
logging.basicConfig(level=logging.INFO)

# File extensions accepted for images fetched from a URL.
VALID_IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png', 'webp')


# Define the model
class KANVGG16(nn.Module):
    """VGG16-style convolutional network with a KANLinear classifier head.

    Each of the five convolutional stages ends with a 2x2 max-pool followed
    by batch normalization. The classifier expects a flattened
    ``512 * 7 * 7`` feature map, i.e. a 224x224 input image.
    """

    def __init__(self, num_classes=1):  # For binary classification (cats and dogs)
        """Build the feature extractor and the classifier.

        Args:
            num_classes: Number of output units of the final KANLinear layer.
        """
        super().__init__()
        # NOTE: the order of modules below fixes the state_dict key indices
        # (features.0, features.2, ...); do not reorder.
        self.features = nn.Sequential(
            # Stage 1: 3 -> 64
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64),  # Added Batch Normalization

            # Stage 2: 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(128),  # Added Batch Normalization

            # Stage 3: 128 -> 256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(256),  # Added Batch Normalization

            # Stage 4: 256 -> 512
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512),  # Added Batch Normalization

            # Stage 5: 512 -> 512
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(512),  # Added Batch Normalization
        )
        self.classifier = nn.Sequential(
            # Five 2x2 pools reduce 224x224 to 7x7 with 512 channels.
            KANLinear(512 * 7 * 7, 2048),  # Adjusted for input size 224x224
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Increased Dropout
            KANLinear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Increased Dropout
            KANLinear(2048, num_classes)
        )

    def forward(self, x):
        """Run the feature extractor, flatten, and apply the classifier."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def load_model(weights_path, device):
    """Create a KANVGG16, load its weights, and put it in eval mode.

    Keys that start with ``'module.'`` have that prefix stripped before the
    state dict is loaded into the model.

    Args:
        weights_path: Path to the saved state dict.
        device: Device the model is moved to and the weights are mapped to.

    Returns:
        The model on ``device`` in evaluation mode.
    """
    model = KANVGG16().to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(weights_path, map_location=device)

    # Remove 'module.' prefix from keys
    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model.eval()
    return model


class CustomImageLoadingError(Exception):
    """Custom exception for image loading errors"""
    pass


def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: Address of the image. The extension of the URL *path* must be
            one of ``VALID_IMAGE_EXTENSIONS``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The downloaded image converted to RGB.

    Raises:
        CustomImageLoadingError: If the extension or content type is not an
            image, the request fails, or the payload cannot be decoded.
    """
    try:
        logging.info("Loading image from URL: %s", url)

        # Check the file extension on the URL path only, so that query
        # strings or fragments (e.g. "cat.jpg?w=200") do not break the check.
        url_path = urlparse(url).path
        file_extension = os.path.splitext(url_path)[1][1:].lower()
        if file_extension not in VALID_IMAGE_EXTENSIONS:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")

        # A timeout keeps the app from hanging on an unresponsive server.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful

        # The header may be absent; treat that as "not an image".
        content_type = response.headers.get('Content-Type', '')
        logging.info("Content-Type: %s", content_type)

        # Check if the content type is an image
        if 'image' not in content_type:
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")

        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Validation errors raised above already carry a precise message;
        # let them through instead of re-wrapping them as "unexpected".
        raise
    except requests.HTTPError as e:
        logging.error("HTTPError while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error("UnidentifiedImageError while loading image: %s", e)
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error("RequestException while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error("Unexpected error while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e


def preprocess_image(image):
    """Resize ``image`` to 224x224, convert it to a tensor, add a batch axis."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    return transform(image).unsqueeze(0)


# Streamlit app
st.title("Cat and Dog Classification with VGG16-KAN")

st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this runs at module level, so the weights are presumably
# reloaded on every Streamlit rerun; consider caching (e.g. st.cache_resource)
# if the installed Streamlit version supports it.
model = load_model('weights/best_model_VGG16_KAN_97.pth', device)

img = None

if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except (UnidentifiedImageError, OSError) as e:
        # A corrupt or mislabelled upload should produce a sidebar error,
        # not an application traceback.
        logging.error("Error while reading uploaded image: %s", e)
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")

# Define your information for the footer
name = "Wayan Dadang"

st.sidebar.write("Follow me on:")
# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)

    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)

        with torch.no_grad():
            output = model(img_tensor)
            # Single-logit output: sigmoid gives the score compared to 0.5.
            prob = torch.sigmoid(output).item()

        st.write(f"Prediction: {prob:.4f}")

        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")