import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import cv2
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
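
# DeepLabV3-ResNet101 pretrained by torchvision (COCO images with Pascal VOC's
# 21 class labels) supplies the semantic mask that conditions colorization.
# Note: `pretrained=True` is deprecated in recent torchvision releases in
# favor of the `weights=...` argument.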
deeplab = deeplabv3_resnet101(pretrained=True).to(device)
deeplab.eval()
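

# Segmentation-assisted colorization network: an EfficientNet-B0 encoder over
# the grayscale L channel fused with a per-pixel class embedding, plus a
# transposed-convolution decoder with additive skip connections that
# regresses the two ab chrominance channels.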
class EfficientUNetWithSeg(nn.Module):
    def __init__(self, n_classes=313):  # n_classes is unused below
        super().__init__()

        self.encoder = EfficientNet.from_pretrained('efficientnet-b0')

        # Project the 2-channel input (L channel + segmentation embedding)
        # to the 3 channels the pretrained encoder stem expects.
        self.input_conv = nn.Conv2d(2, 3, kernel_size=1)

        # Split the EfficientNet-B0 backbone into five stages so the decoder
        # can reuse intermediate feature maps as skip connections.
        self.enc1 = nn.Sequential(
            self.encoder._conv_stem,
            self.encoder._bn0,
            self.encoder._swish
        )
        self.enc2 = nn.Sequential(*self.encoder._blocks[0:2])
        self.enc3 = nn.Sequential(*self.encoder._blocks[2:4])
        self.enc4 = nn.Sequential(*self.encoder._blocks[4:10])
        self.enc5 = nn.Sequential(*self.encoder._blocks[10:])

        # Decoder: each block doubles the spatial resolution and matches the
        # channel count of the corresponding encoder stage.
        self.up4 = self._up_block(320, 112)
        self.up3 = self._up_block(112, 40)
        self.up2 = self._up_block(40, 24)
        self.up1 = self._up_block(24, 32)

        # One scalar embedding per segmentation class (21 Pascal VOC classes,
        # matching the DeepLabV3 head).
        self.seg_embed = nn.Embedding(21, 1)

        # Predict the two chrominance channels (a, b) and recover the
        # resolution lost in the encoder stem.
        self.final_conv = nn.Conv2d(32, 2, kernel_size=1)
        self.upsample_final = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def _up_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, l, seg_mask):
        """
        l:        [B, 1, H, W] -> L channel (grayscale)
        seg_mask: [B, H, W]    -> segmentation mask (class IDs, long/int dtype)
        """
        # Embed class IDs into a single channel:
        # [B, H, W] -> [B, H, W, 1] -> [B, 1, H, W]
        seg_emb = self.seg_embed(seg_mask.long())
        seg_emb = seg_emb.permute(0, 3, 1, 2)

        # Fuse luminance and semantics, then adapt to the 3-channel stem.
        x = torch.cat([l, seg_emb], dim=1)
        x = self.input_conv(x)

        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)

        # Decoder with additive skip connections to the encoder stages.
        u4 = self.up4(x5) + x4
        u3 = self.up3(u4) + x3
        u2 = self.up2(u3) + x2
        u1 = self.up1(u2) + x1

        out = self.final_conv(u1)
        out = self.upsample_final(out)
        return out
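

# Run DeepLabV3 on an RGB numpy image and return a (256, 256) mask of
# per-pixel class IDs.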
def get_segmentation_mask_from_np(rgb_np):
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # ImageNet statistics expected by torchvision's pretrained DeepLabV3.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform(rgb_np).unsqueeze(0).to(device)
    with torch.no_grad():
        output = deeplab(input_tensor)['out'][0]
    seg_mask = output.argmax(0).cpu().numpy()
    return seg_mask
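

# Reassemble an 8-bit Lab image and convert it to RGB. OpenCV's 8-bit Lab
# convention stores L in [0, 255] with a/b offset by 128; the model's ab
# output is assumed to lie roughly in [-1, 1], matching the training-time
# scaling.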
def lab_to_rgb(L, ab):
    if len(L.shape) == 3:
        L = L[0]
    L = (L * 255.0).astype(np.uint8)
    # Clip before the uint8 cast so out-of-range predictions don't wrap around.
    a = np.clip(ab[0] * 127.0 + 128, 0, 255).astype(np.uint8)
    b = np.clip(ab[1] * 127.0 + 128, 0, 255).astype(np.uint8)
    lab = np.stack([L, a, b], axis=2)
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return rgb
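

# End-to-end inference: run the network at 256x256, then upscale the predicted
# ab channels and merge them with the full-resolution L channel before saving.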
def colorize_image(gray_img_pil, mode, file_format):
    # NOTE: `mode` is currently unused; "Basic" and "Advanced" run the same pipeline.
    # Normalize the UI selection to the format name Pillow expects; only "JPG"
    # needs remapping ("JPEG"), the rest already match.
    file_format = file_format.upper()
    if file_format == "JPG":
        file_format = "JPEG"

    # Keep the original resolution for the final output; the network itself
    # runs at 256x256.
    gray_np_original = np.array(gray_img_pil.convert("L"))
    orig_h, orig_w = gray_np_original.shape

    gray_resized = cv2.resize(gray_np_original, (256, 256)) / 255.0
    L_tensor = torch.tensor(gray_resized).unsqueeze(0).unsqueeze(0).float().to(device)

    # DeepLabV3 expects 3 channels, so replicate the gray image into RGB.
    rgb_simulated = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2RGB)
    rgb_resized = cv2.resize(rgb_simulated, (256, 256))
    seg_mask = get_segmentation_mask_from_np(rgb_resized)
    seg_tensor = torch.tensor(seg_mask).unsqueeze(0).to(device)

    # Predict ab at 256x256, then upscale the chrominance to the original
    # resolution.
    with torch.no_grad():
        ab_pred = model(L_tensor, seg_tensor)
    ab_pred_np = ab_pred[0].cpu().numpy()

    ab_resized = np.stack([
        cv2.resize(ab_pred_np[0], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC),
        cv2.resize(ab_pred_np[1], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
    ], axis=0)

    # Recover the full-resolution L channel via OpenCV's Lab conversion so
    # the output keeps the input's sharpness.
    L_bgr = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2BGR)
    L_lab = cv2.cvtColor(L_bgr, cv2.COLOR_BGR2LAB)
    L_full = L_lab[:, :, 0] / 255.0

    rgb_output = lab_to_rgb(L_full, ab_resized)

    input_show = Image.fromarray(gray_np_original).convert("RGB")
    output_show = Image.fromarray(rgb_output)

    save_path = f"/tmp/output_colored.{file_format.lower()}"
    output_show.save(save_path, format=file_format)

    return [input_show, output_show], save_path
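

# Load the trained colorization weights and switch to inference mode.
# (On newer PyTorch, consider torch.load(..., weights_only=True), since the
# checkpoint here is a plain state_dict.)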
model = EfficientUNetWithSeg()
model.load_state_dict(torch.load("best_model_earlystop_BESTMODEL.pth", map_location=device))
model.to(device)
model.eval()
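

# Gradio front end: upload a grayscale image, choose an output format, and
# get a before/after gallery plus a downloadable file.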
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("## 🎨 AI-Powered Image Colorization")
    gr.Markdown("Colorize black-and-white images using a segmentation-assisted EfficientUNet model.")

    input_image = gr.Image(label="🖼️ Upload Grayscale Image", type="pil")

    with gr.Row():
        mode = gr.Radio(["Basic", "Advanced"], value="Basic", label="🧭 Mode")
        file_format = gr.Radio(["PNG", "JPG", "WEBP", "TIFF"], value="PNG", label="🗂️ Output Format")

    run_button = gr.Button("🚀 Colorize")

    output_gallery = gr.Gallery(label="🎬 Before and After", columns=2, height=300)
    download_button = gr.File(label="⬇ Download Colorized Image")

    def process_wrapper(img, mode, fmt):
        try:
            gallery, path = colorize_image(img, mode, fmt)
            return gallery, path
        except Exception as e:
            import traceback
            # Log the full traceback server-side; surface a short message in the UI.
            print("🔥 ERROR:\n", traceback.format_exc())
            raise gr.Error(f"An error occurred:\n{str(e)}")

    run_button.click(fn=process_wrapper,
                     inputs=[input_image, mode, file_format],
                     outputs=[output_gallery, download_button])

demo.launch()