# Deblur / app.py
# Last commit: 44edf20 "Update app.py" by pratyyush (verified)
import os
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import torchvision.transforms.functional as F
# Select the compute device once, at import time: CUDA when PyTorch can see a
# GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class DeblurNet(nn.Module):
    """Three-level U-Net style encoder/decoder used for image deblurring.

    Layer names and channel widths must match the saved checkpoint, so they
    must not be renamed or resized without retraining:

    - encoder stages 3 -> 64 -> 128 -> 256;
    - a 512-channel bottleneck;
    - three decoder stages, each of which concatenates the matching encoder
      feature map (skip connection) before convolving.

    Input is a ``(N, 3, H, W)`` tensor. The output has the same shape and is
    squashed into ``[-1, 1]`` by a final ``tanh``.

    Both spatial dimensions must be at least 8, because there are three 2x2
    max-pools. They do not have to be multiples of 8: each decoder stage
    resizes to the exact spatial size of its skip connection.
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stage doubles the channel count; ``forward`` halves the
        # spatial size with ``self.pool`` between stages.
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        # 512 bottleneck channels are required to load the saved checkpoint.
        self.bottleneck = self.conv_block(256, 512)
        # Decoder input channels = upsampled features + concatenated skip.
        self.dec_conv1 = self.conv_block(512 + 256, 256)
        self.dec_conv2 = self.conv_block(256 + 128, 128)
        self.dec_conv3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Kept as the single source of the interpolation settings (mode and
        # align_corners). ``_upsample_to`` reuses them but targets the skip
        # tensor's exact size instead of a fixed 2x factor.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def _upsample_to(self, x, skip):
        """Bilinearly resize ``x`` to the spatial size of ``skip``.

        When ``skip`` has even height and width, this gives the same result
        as ``self.upsample`` (a 2x resize with ``align_corners=True``).

        When max-pooling floored an odd size, it recovers the missing
        row/column, so the channel concatenation in ``forward`` cannot fail
        on a size mismatch.
        """
        return nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        # Encoder path: keep every stage's output for the skip connections.
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))
        out = self.bottleneck(self.pool(skip3))
        # Decoder path: upsample, concatenate the matching skip along the
        # channel axis, then convolve.
        out = self.dec_conv1(torch.cat([self._upsample_to(out, skip3), skip3], dim=1))
        out = self.dec_conv2(torch.cat([self._upsample_to(out, skip2), skip2], dim=1))
        out = self.dec_conv3(torch.cat([self._upsample_to(out, skip1), skip1], dim=1))
        return torch.tanh(self.final_conv(out))
# Load model
model = DeblurNet().to(device)
model_path = os.path.join('model', 'best_deblur_model.pth')
if os.path.exists(model_path):
    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and basic containers instead of arbitrary objects.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully!")
else:
    # NOTE: the app keeps running with the randomly initialised weights created
    # above, so outputs will be meaningless until the checkpoint is provided.
    print("Model file not found. Please upload 'best_deblur_model.pth' to the /model folder.")
# Switch to inference mode unconditionally, so the module's mode does not
# depend on whether the checkpoint file happened to be present.
model.eval()
# --- DYNAMIC PROCESSING ---
def deblur_image(img):
    """Run the deblurring model on one image and return the result.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray``. Presumably the
            numpy array produced by the ``gr.Image`` input; verify against
            the UI. It may be ``None`` when nothing was uploaded.

    Returns:
        An ``(H, W, 3)`` RGB array with the same width and height as the
        input, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None
    input_image = Image.fromarray(img).convert("RGB")
    w, h = input_image.size

    input_tensor = F.to_tensor(input_image).unsqueeze(0).to(device)
    # Map [0, 1] -> [-1, 1]. The network ends in tanh, so its output lives in
    # the same range and is mapped back below.
    input_tensor = F.normalize(input_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    # The model max-pools three times (2**3 = 8), so pad H and W up to the
    # next multiple of 8 and crop afterwards.
    # Padding rather than resizing avoids two resampling passes that would
    # soften the detail being restored. It also works for images smaller than
    # 8 px in either dimension.
    pad_w = -w % 8
    pad_h = -h % 8
    # 'replicate' (unlike 'reflect') accepts padding larger than the image.
    input_tensor = torch.nn.functional.pad(
        input_tensor, (0, pad_w, 0, pad_h), mode='replicate'
    )

    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Drop the batch dimension and the padded border, then map [-1, 1] -> [0, 1].
    output_tensor = output_tensor.squeeze(0)[:, :h, :w].cpu()
    output_tensor = torch.clamp(output_tensor * 0.5 + 0.5, 0, 1)
    output_image = F.to_pil_image(output_tensor)
    return np.array(output_image)
# Hide Gradio's footer and force a black page background.
custom_css = "footer {visibility: hidden} .gradio-container {background-color: #000 !important;}"

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# ✨ AI Image Deblur")
    # Input and output images shown side by side.
    with gr.Row():
        in_img = gr.Image(label="Blurry Image")
        out_img = gr.Image(label="Deblurred")
    btn = gr.Button("Process", variant="primary")
    # Pressing the button sends the uploaded image through the model.
    btn.click(fn=deblur_image, inputs=in_img, outputs=out_img)

demo.launch()