import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import cv2
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained DeepLabV3 (21 Pascal VOC classes) supplies the segmentation prior.
deeplab = deeplabv3_resnet101(pretrained=True).to(device)
deeplab.eval()
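
# Hedged sketch: what the DeepLab call is expected to return. The 21 output
# channels correspond to the Pascal VOC classes (background + 20 objects),
# which is why the colorization model below embeds 21 class IDs. The dummy
# size is an illustrative assumption; call manually when debugging.
def _inspect_deeplab_output():
    dummy = torch.rand(1, 3, 256, 256, device=device)
    with torch.no_grad():
        out = deeplab(dummy)['out']        # [1, 21, 256, 256] class logits
    print(out.shape, out.argmax(1).unique())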

class EfficientUNetWithSeg(nn.Module):
    def __init__(self, n_classes=313):
        super().__init__()
        self.encoder = EfficientNet.from_pretrained('efficientnet-b0')
        # Map the 2-channel input (L + segmentation embedding) to the 3 channels
        # the EfficientNet stem expects.
        self.input_conv = nn.Conv2d(2, 3, kernel_size=1)
        self.enc1 = nn.Sequential(
            self.encoder._conv_stem,
            self.encoder._bn0,
            self.encoder._swish
        )                                                        # [B, 32, H/2, W/2]
        self.enc2 = nn.Sequential(*self.encoder._blocks[0:2])   # [B, 24, H/4, W/4]
        self.enc3 = nn.Sequential(*self.encoder._blocks[2:4])   # [B, 40, H/8, W/8]
        self.enc4 = nn.Sequential(*self.encoder._blocks[4:10])  # [B, 112, H/16, W/16]
        self.enc5 = nn.Sequential(*self.encoder._blocks[10:])   # [B, 320, H/32, W/32]
        # Decoder (U-Net style)
        self.up4 = self._up_block(320, 112)
        self.up3 = self._up_block(112, 40)
        self.up2 = self._up_block(40, 24)
        self.up1 = self._up_block(24, 32)
        # Segmentation-mask embedding: each of the 21 VOC class IDs maps to a scalar.
        self.seg_embed = nn.Embedding(21, 1)
        # Final prediction: two chroma channels (a, b)
        self.final_conv = nn.Conv2d(32, 2, kernel_size=1)
        self.upsample_final = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def _up_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
    def forward(self, l, seg_mask):
        """
        l:        [B, 1, H, W] -> L channel (grayscale)
        seg_mask: [B, H, W]    -> segmentation mask (class IDs as long/int)
        """
        seg_emb = self.seg_embed(seg_mask.long())  # [B, H, W, 1]
        seg_emb = seg_emb.permute(0, 3, 1, 2)      # [B, 1, H, W]
        x = torch.cat([l, seg_emb], dim=1)         # [B, 2, H, W]
        x = self.input_conv(x)                     # [B, 3, H, W]
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        # Decoder with additive skip connections
        u4 = self.up4(x5) + x4
        u3 = self.up3(u4) + x3
        u2 = self.up2(u3) + x2
        u1 = self.up1(u2) + x1
        out = self.final_conv(u1)
        out = self.upsample_final(out)
        return out
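
# Hedged sketch: a quick shape check for the forward pass. The dummy sizes
# below (batch of 1, 256x256, class IDs in [0, 21)) are assumptions for
# illustration, not part of the original app; call manually when debugging.
def _sanity_check_forward():
    m = EfficientUNetWithSeg()
    l = torch.rand(1, 1, 256, 256)                 # L channel in [0, 1]
    seg = torch.randint(0, 21, (1, 256, 256))      # fake VOC class IDs
    with torch.no_grad():
        ab = m(l, seg)
    assert ab.shape == (1, 2, 256, 256), ab.shape  # two chroma channels, full res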

def get_segmentation_mask_from_np(rgb_np):
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform(rgb_np).unsqueeze(0).to(device)
    with torch.no_grad():
        output = deeplab(input_tensor)['out'][0]
    seg_mask = output.argmax(0).cpu().numpy()
    return seg_mask
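
# Hedged usage sketch: feeding a synthetic RGB array through the mask helper.
# The horizontal-gradient image is an illustrative assumption, not app data.
def _demo_segmentation_mask():
    fake_rgb = np.tile(np.arange(256, dtype=np.uint8), (256, 1))
    fake_rgb = np.stack([fake_rgb] * 3, axis=2)    # (256, 256, 3) uint8
    mask = get_segmentation_mask_from_np(fake_rgb)
    print(mask.shape, np.unique(mask))             # (256, 256), VOC class IDs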

def lab_to_rgb(L, ab):
    if len(L.shape) == 3:
        L = L[0]
    # OpenCV's 8-bit LAB convention: L in [0, 255], a and b offset by 128.
    L = (L * 255.0).astype(np.uint8)
    a = (ab[0] * 127.0 + 128).astype(np.uint8)
    b = (ab[1] * 127.0 + 128).astype(np.uint8)
    lab = np.stack([L, a, b], axis=2)
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return rgb
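
# Hedged sketch: a neutral-gray round trip through lab_to_rgb. With zero
# chroma (a = b = 0 before scaling) the output should stay near-neutral gray.
# The sizes and the mid-gray value are assumptions for illustration.
def _check_lab_to_rgb_neutral():
    L = np.full((8, 8), 0.5, dtype=np.float32)     # mid-gray L channel
    ab = np.zeros((2, 8, 8), dtype=np.float32)     # no chroma
    rgb = lab_to_rgb(L, ab)
    assert rgb.shape == (8, 8, 3)
    print(rgb[0, 0])                               # roughly equal R, G, B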

def colorize_image(gray_img_pil, mode, file_format):
    # Normalize the format name for PIL ("JPG" must be saved as "JPEG").
    file_format = file_format.upper()
    if file_format == "JPG":
        file_format = "JPEG"
    # NOTE: `mode` is currently unused; both radio options run the same pipeline.
    gray_np_original = np.array(gray_img_pil.convert("L"))  # (H, W)
    orig_h, orig_w = gray_np_original.shape
    gray_resized = cv2.resize(gray_np_original, (256, 256)) / 255.0
    L_tensor = torch.tensor(gray_resized).unsqueeze(0).unsqueeze(0).float().to(device)
    # Build a fake RGB image for the segmentation model and extract the mask.
    rgb_simulated = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2RGB)
    rgb_resized = cv2.resize(rgb_simulated, (256, 256))
    seg_mask = get_segmentation_mask_from_np(rgb_resized)
    seg_tensor = torch.tensor(seg_mask).unsqueeze(0).to(device)
    with torch.no_grad():
        ab_pred = model(L_tensor, seg_tensor)
    ab_pred_np = ab_pred[0].cpu().numpy()  # (2, 256, 256)
    # Upsample the predicted chroma back to the original resolution.
    ab_resized = np.stack([
        cv2.resize(ab_pred_np[0], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC),
        cv2.resize(ab_pred_np[1], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
    ], axis=0)  # (2, H, W)
    # Recover the full-resolution L channel via an LAB round trip.
    L_bgr = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2BGR)
    L_lab = cv2.cvtColor(L_bgr, cv2.COLOR_BGR2LAB)
    L_full = L_lab[:, :, 0] / 255.0  # (H, W), float
    rgb_output = lab_to_rgb(L_full, ab_resized)
    input_show = Image.fromarray(gray_np_original).convert("RGB")
    output_show = Image.fromarray(rgb_output)
    save_path = f"/tmp/output_colored.{file_format.lower()}"
    output_show.save(save_path, format=file_format)
    return [input_show, output_show], save_path
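
# Hedged usage sketch: driving colorize_image outside Gradio. The synthetic
# checkerboard input and the "PNG" choice are illustrative assumptions; run
# manually once `model` (loaded below) is available.
def _demo_colorize():
    arr = (np.indices((128, 128)).sum(axis=0) % 2 * 255).astype(np.uint8)
    pil = Image.fromarray(arr)                     # tiny checkerboard image
    (before, after), path = colorize_image(pil, "Basic", "PNG")
    print(before.size, after.size, path)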
model = EfficientUNetWithSeg()
model.load_state_dict(torch.load("best_model_earlystop_BESTMODEL.pth", map_location=device))
model.to(device)
model.eval()
# Gradio interface
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("## 🎨 AI-Powered Image Colorization")
    gr.Markdown("Colorize black-and-white images using a segmentation-assisted EfficientUNet model.")
    input_image = gr.Image(label="🖼️ Upload Grayscale Image", type="pil")
    with gr.Row():
        mode = gr.Radio(["Basic", "Advanced"], value="Basic", label="🧭 Mode")
        file_format = gr.Radio(["PNG", "JPG", "WEBP", "TIFF"], value="PNG", label="🗂️ Output Format")
    run_button = gr.Button("🚀 Colorize")
    output_gallery = gr.Gallery(label="🎬 Before and After", columns=2, height=300)
    download_button = gr.File(label="⬇ Download Colorized Image")

    def process_wrapper(img, mode, fmt):
        try:
            gallery, path = colorize_image(img, mode, fmt)
            return gallery, path
        except Exception as e:
            import traceback
            print("🔥 ERROR:\n", traceback.format_exc())
            raise gr.Error(f"An error occurred:\n{str(e)}")

    run_button.click(fn=process_wrapper,
                     inputs=[input_image, mode, file_format],
                     outputs=[output_gallery, download_button])

demo.launch()