| import os |
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| import numpy as np |
| import torchvision.transforms.functional as F |
|
|
| |
# Run on the current CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
class DeblurNet(nn.Module):
    """Three-level U-Net mapping a 3-channel image to a 3-channel image.

    The encoder applies 2x2 max pooling three times and the decoder
    upsamples by 2 three times, concatenating each result with the encoder
    feature map of the same level. Input height and width must therefore be
    divisible by 8, or the tensors joined by ``torch.cat`` will not match.
    The final ``tanh`` keeps every output value in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        # Channel widths per level; the bottleneck is the deepest level.
        c1, c2, c3, c4 = 64, 128, 256, 512

        # Encoder (registration order kept so state_dict keys and seeded
        # initialisation match existing checkpoints).
        self.enc_conv1 = self.conv_block(3, c1)
        self.enc_conv2 = self.conv_block(c1, c2)
        self.enc_conv3 = self.conv_block(c2, c3)

        self.bottleneck = self.conv_block(c3, c4)

        # Decoder input = upsampled features + same-level skip connection.
        self.dec_conv1 = self.conv_block(c4 + c3, c3)
        self.dec_conv2 = self.conv_block(c3 + c2, c2)
        self.dec_conv3 = self.conv_block(c2 + c1, c1)

        self.final_conv = nn.Conv2d(c1, 3, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the encoder, bottleneck and skip-connected decoder on ``x``."""
        skip1 = self.enc_conv1(x)
        skip2 = self.enc_conv2(self.pool(skip1))
        skip3 = self.enc_conv3(self.pool(skip2))
        y = self.bottleneck(self.pool(skip3))

        # Walk back up, pairing each decoder stage with its encoder skip.
        stages = (
            (self.dec_conv1, skip3),
            (self.dec_conv2, skip2),
            (self.dec_conv3, skip1),
        )
        for decoder, skip in stages:
            y = decoder(torch.cat([self.upsample(y), skip], dim=1))

        return torch.tanh(self.final_conv(y))
|
|
| |
# Build the network on the selected device and try to load trained weights.
model = DeblurNet().to(device)
# NOTE: this relative path is resolved against the current working
# directory, not the directory containing this script.
model_path = os.path.join('model', 'best_deblur_model.pth')

if os.path.exists(model_path):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code while loading.
    # (Available since PyTorch 1.13; the default from 2.6 onward.)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully!")
else:
    print("Model file not found. Please upload 'best_deblur_model.pth' to the /model folder.")

# Switch to inference mode on both paths so the app never serves
# predictions from a module left in training mode.
model.eval()
|
|
| |
def deblur_image(img):
    """Deblur a single image with the module-level ``model``.

    Args:
        img: Image array from the ``gr.Image`` input (presumably an
            ``H x W`` or ``H x W x C`` uint8 array; confirm against the
            Gradio version in use), or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if ``img`` is ``None``. Otherwise an ``H x W x 3`` uint8
        array with the same height and width as the input.
    """
    if img is None:
        return None

    # PIL converts grayscale / RGBA / palette input to 3-channel RGB.
    input_image = Image.fromarray(img).convert("RGB")
    w, h = input_image.size

    # Scale pixels from [0, 1] to [-1, 1], the range of DeblurNet's tanh output.
    input_tensor = F.to_tensor(input_image).unsqueeze(0).to(device)
    input_tensor = F.normalize(input_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    # DeblurNet pools three times, so H and W must be multiples of 8.
    # Pad the bottom/right edges instead of resizing. Resizing to (w//8)*8
    # gave a zero-size target (and a crash) for images under 8 px. It also
    # resampled every pixel twice, softening the detail we want to restore.
    pad_h = (-h) % 8
    pad_w = (-w) % 8
    if pad_h or pad_w:
        # "replicate" has no size limit, unlike "reflect", so even 1-px images work.
        input_tensor = nn.functional.pad(
            input_tensor, (0, pad_w, 0, pad_h), mode="replicate"
        )

    with torch.no_grad():
        output_tensor = model(input_tensor)

    # Drop the batch dim and the padding, then map [-1, 1] back to [0, 1].
    output_tensor = output_tensor[0, :, :h, :w].cpu()
    output_tensor = torch.clamp(output_tensor * 0.5 + 0.5, 0, 1)
    return np.array(F.to_pil_image(output_tensor))
|
|
# Hide Gradio's footer and force a black page background.
custom_css = " ".join((
    "footer {visibility: hidden}",
    ".gradio-container {background-color: #000 !important;}",
))

# Page layout: title, input and output images side by side, then a button
# that runs deblur_image on the input and shows the result in the output.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# ✨ AI Image Deblur")
    with gr.Row():
        in_img = gr.Image(label="Blurry Image")
        out_img = gr.Image(label="Deblurred")
    btn = gr.Button("Process", variant="primary")
    btn.click(fn=deblur_image, inputs=in_img, outputs=out_img)

demo.launch()