wayandadang's picture
Update app.py
0d173bb verified
raw
history blame contribute delete
No virus
7.01 kB
import logging
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, UnidentifiedImageError
from torchvision import transforms, models

from kan_linear import KANLinear
# Setup logging: lower the root logger's threshold to INFO so the
# logging.info(...) progress messages emitted by this app are shown.
logging.basicConfig(level="INFO")
# Define the model
class KANVGG16(nn.Module):
    """VGG16-style convolutional backbone followed by a KANLinear classifier head.

    The feature extractor has five stages. Each stage is a run of 3x3
    convolutions (padding 1), each followed by an in-place ReLU, and is closed
    by a 2x2 max-pool and a BatchNorm2d over the stage's channels. Layers are
    registered in exactly the order of the original hand-written
    ``nn.Sequential``, so state-dict keys such as ``features.0.weight`` are
    unchanged and existing checkpoints still load.

    Args:
        num_classes: Width of the final KANLinear output. Defaults to 1, a
            single logit for binary (cat vs. dog) classification.
    """

    # (output channels, number of 3x3 convolutions) for each of the five stages.
    _STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    def __init__(self, num_classes=1):
        super().__init__()

        layers = []
        in_channels = 3
        for out_channels, conv_count in self._STAGES:
            for _ in range(conv_count):
                layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = out_channels
            # Each stage ends by halving the spatial size, then normalizing.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers.append(nn.BatchNorm2d(out_channels))
        self.features = nn.Sequential(*layers)

        # Five stride-2 pools shrink a 224x224 input to 7x7 with 512 channels.
        flat_dim = 512 * 7 * 7
        hidden_dim = 2048
        self.classifier = nn.Sequential(
            KANLinear(flat_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Run the backbone, flatten everything after the batch axis, and classify.

        Returns the raw classifier output (no sigmoid/softmax applied).
        """
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)
def load_model(weights_path, device):
    """Build a KANVGG16 on ``device``, load saved weights into it, and set eval mode.

    Keys beginning with ``module.`` have that prefix removed before loading.
    Such a prefix is typically left by a model wrapped in ``nn.DataParallel``
    during training; verify against how the checkpoint was saved.

    Args:
        weights_path: File passed to ``torch.load``. It must contain a mapping
            of parameter names to tensors (a state dict).
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The KANVGG16 instance with weights loaded, in evaluation mode.
    """
    net = KANVGG16().to(device)
    checkpoint = torch.load(weights_path, map_location=device)

    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }

    net.load_state_dict(cleaned)
    net.eval()
    return net
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be fetched, validated, or decoded from a URL."""
def load_image_from_url(url, timeout=10):
    """Download the image at ``url`` and return it as an RGB PIL image.

    The URL path must end in .jpg/.jpeg/.png/.webp; this is checked before any
    network access. The HTTP response must also carry an image Content-Type.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get``. This bounds how long a slow
            or unresponsive server can block the app. Defaults to 10.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        CustomImageLoadingError: On every failure, including an unsupported
            extension, a non-image Content-Type, an HTTP or network error, or
            undecodable data. Any underlying exception is chained as
            ``__cause__``.
    """
    valid_extensions = ('jpg', 'jpeg', 'png', 'webp')
    try:
        logging.info(f"Loading image from URL: {url}")
        # Take the extension from the URL *path* only, so a query string or
        # fragment (e.g. "cat.jpg?w=400") does not end up in the extension.
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful
        # A response without the header is treated as "not an image" rather
        # than surfacing as a KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")
        # Media types are case-insensitive, so compare in lower case.
        if 'image' not in content_type.lower():
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError as e:
        # Validation failures raised above already carry a precise message.
        # Re-raise them unchanged instead of letting the catch-all below
        # re-wrap them as "unexpected" errors.
        logging.error(f"Image validation failed: {e}")
        raise
    except requests.HTTPError as e:
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        # Also covers timeouts and connection failures.
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
def preprocess_image(image):
    """Convert a PIL image into a single-image batch for KANVGG16.

    The image is resized to exactly 224x224, the size the classifier's
    512 * 7 * 7 input assumes. It is then converted with torchvision's
    ``ToTensor``. No normalization is applied here.
    NOTE(review): confirm this matches the preprocessing used in training.

    Returns:
        The image tensor with a leading batch dimension of size 1 added.
    """
    steps = [
        transforms.Resize((224, 224)),  # fixed output size; aspect ratio not kept
        transforms.ToTensor(),
    ]
    single = transforms.Compose(steps)(image)
    return single.unsqueeze(0)
# Streamlit app
#
# Streamlit re-runs this whole script on every widget interaction (including
# clicking "Predict"). The model is therefore loaded through a cached helper
# instead of being rebuilt from disk on every rerun.
WEIGHTS_PATH = 'weights/best_model_VGG16_KAN_97.pth'


@st.cache_resource
def _get_model(weights_path, device_name):
    """Return the loaded classifier, building it only once per (path, device)."""
    # The device is passed as a plain string so the cache key is trivially hashable.
    return load_model(weights_path, torch.device(device_name))


st.title("Cat and Dog Classification with VGG16-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_model(WEIGHTS_PATH, str(device))

img = None
if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        # PIL signals undecodable uploads with UnidentifiedImageError (an
        # OSError subclass) or OSError. Report the problem in the sidebar
        # instead of crashing the page.
        logging.error(f"Could not decode uploaded file: {e}")
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url.strip():
    # Ignore whitespace-only input and stray spaces pasted around the URL.
    try:
        img = load_image_from_url(image_url.strip())
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")
# Define your information for the footer
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")
# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)
        with torch.no_grad():
            output = model(img_tensor)
        # The model emits a single logit; sigmoid maps it to [0, 1].
        # Values below 0.5 are labeled Cat, others Dog.
        prob = torch.sigmoid(output).item()
        st.write(f"Prediction: {prob:.4f}")
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")