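"""Gradio demo: grayscale image colorization with a segmentation-assisted
EfficientNet-B0 U-Net. A pretrained DeepLabV3 produces a semantic class-ID
mask that is embedded and fed to the colorization network alongside the
L (lightness) channel."""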
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import cv2
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained DeepLabV3 (21 PASCAL VOC classes) supplies the segmentation prior.
deeplab = deeplabv3_resnet101(pretrained=True).to(device)
deeplab.eval()
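# Colorization network: EfficientNet-B0 encoder, U-Net style decoder with
# additive skip connections. The DeepLabV3 class-ID mask is embedded into a
# single extra channel and concatenated with the L channel at the input.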
class EfficientUNetWithSeg(nn.Module):
    def __init__(self, n_classes=313):  # n_classes is unused; the head regresses two ab channels directly
        super().__init__()
        self.encoder = EfficientNet.from_pretrained('efficientnet-b0')
        self.input_conv = nn.Conv2d(2, 3, kernel_size=1)
        self.enc1 = nn.Sequential(
            self.encoder._conv_stem,
            self.encoder._bn0,
            self.encoder._swish
        )  # [B, 32, H/2, W/2]
        self.enc2 = nn.Sequential(*self.encoder._blocks[0:2])   # [B, 24, H/4, W/4]
        self.enc3 = nn.Sequential(*self.encoder._blocks[2:4])   # [B, 40, H/8, W/8]
        self.enc4 = nn.Sequential(*self.encoder._blocks[4:10])  # [B, 112, H/16, W/16]
        self.enc5 = nn.Sequential(*self.encoder._blocks[10:])   # [B, 320, H/32, W/32]
        # Decoder (U-Net style)
        self.up4 = self._up_block(320, 112)
        self.up3 = self._up_block(112, 40)
        self.up2 = self._up_block(40, 24)
        self.up1 = self._up_block(24, 32)
        # Segmentation mask embedding: 21 class IDs -> a single scalar channel
        self.seg_embed = nn.Embedding(21, 1)
        # Final prediction: two ab channels, upsampled back to input resolution
        self.final_conv = nn.Conv2d(32, 2, kernel_size=1)
        self.upsample_final = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    def _up_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
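    # Each _up_block doubles spatial resolution; out_ch is chosen to match the
    # encoder feature map it is summed with in forward().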
    def forward(self, l, seg_mask):
        """
        l:        [B, 1, H, W] L channel (grayscale)
        seg_mask: [B, H, W]    segmentation mask of class IDs (long/int)
        """
        seg_emb = self.seg_embed(seg_mask.long())  # [B, H, W, 1]
        seg_emb = seg_emb.permute(0, 3, 1, 2)      # [B, 1, H, W]
        x = torch.cat([l, seg_emb], dim=1)         # [B, 2, H, W]
        x = self.input_conv(x)                     # [B, 3, H, W]
        # Encoder
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        # Decoder with additive skip connections
        u4 = self.up4(x5) + x4
        u3 = self.up3(u4) + x3
        u2 = self.up2(u3) + x2
        u1 = self.up1(u2) + x1
        out = self.final_conv(u1)
        out = self.upsample_final(out)
        return out
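# Shape walk-through for a 256x256 input (illustrative only, not run here):
#   l = torch.zeros(1, 1, 256, 256); mask = torch.zeros(1, 256, 256).long()
#   EfficientUNetWithSeg()(l, mask).shape  # -> torch.Size([1, 2, 256, 256])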
def get_segmentation_mask_from_np(rgb_np):
    """Run DeepLabV3 on an RGB numpy image and return a (256, 256) class-ID mask."""
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform(rgb_np).unsqueeze(0).to(device)
    with torch.no_grad():
        output = deeplab(input_tensor)['out'][0]
    seg_mask = output.argmax(0).cpu().numpy()
    return seg_mask
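# Reassemble an RGB image from the (normalized) L channel and predicted ab
# channels using OpenCV's 8-bit LAB representation.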
def lab_to_rgb(L, ab):
    if len(L.shape) == 3:
        L = L[0]
    # Map L from [0, 1] and ab from roughly [-1, 1] into OpenCV's 8-bit LAB
    # ranges; clip before casting so out-of-range predictions do not wrap.
    L = np.clip(L * 255.0, 0, 255).astype(np.uint8)
    a = np.clip(ab[0] * 127.0 + 128, 0, 255).astype(np.uint8)
    b = np.clip(ab[1] * 127.0 + 128, 0, 255).astype(np.uint8)
    lab = np.stack([L, a, b], axis=2)
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return rgb
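# Full inference pipeline: grayscale PIL image in, (before, after) gallery
# pair and a saved file path out.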
def colorize_image(gray_img_pil, mode, file_format):
    # NOTE: "mode" is passed through from the UI but not currently used.
    # Normalize the format name for PIL ("JPG" is not a valid PIL format name).
    file_format = file_format.upper()
    if file_format == "JPG":
        file_format = "JPEG"
    gray_np_original = np.array(gray_img_pil.convert("L"))  # (H, W)
    orig_h, orig_w = gray_np_original.shape
    # Resize to the model's 256x256 input and scale L to [0, 1]
    gray_resized = cv2.resize(gray_np_original, (256, 256)) / 255.0
    L_tensor = torch.tensor(gray_resized).unsqueeze(0).unsqueeze(0).float().to(device)
    # Build a simulated RGB image from the grayscale input and get its segmentation mask
    rgb_simulated = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2RGB)
    rgb_resized = cv2.resize(rgb_simulated, (256, 256))
    seg_mask = get_segmentation_mask_from_np(rgb_resized)
    seg_tensor = torch.tensor(seg_mask).unsqueeze(0).to(device)
    with torch.no_grad():
        ab_pred = model(L_tensor, seg_tensor)
    ab_pred_np = ab_pred[0].cpu().numpy()  # (2, 256, 256)
    # Upscale the predicted ab channels back to the original resolution
    ab_resized = np.stack([
        cv2.resize(ab_pred_np[0], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC),
        cv2.resize(ab_pred_np[1], (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
    ], axis=0)  # (2, H, W)
    # Recover the full-resolution L channel via OpenCV's LAB conversion
    L_bgr = cv2.cvtColor(gray_np_original, cv2.COLOR_GRAY2BGR)
    L_lab = cv2.cvtColor(L_bgr, cv2.COLOR_BGR2LAB)
    L_full = L_lab[:, :, 0] / 255.0  # (H, W), float in [0, 1]
    rgb_output = lab_to_rgb(L_full, ab_resized)
    input_show = Image.fromarray(gray_np_original).convert("RGB")
    output_show = Image.fromarray(rgb_output)
    save_path = f"/tmp/output_colored.{file_format.lower()}"
    output_show.save(save_path, format=file_format)
    return [input_show, output_show], save_path
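# Load the trained colorization weights; the checkpoint file is expected in
# the working directory.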
model = EfficientUNetWithSeg()
model.load_state_dict(torch.load("best_model_earlystop_BESTMODEL.pth", map_location=device))
model.to(device)
model.eval()
# Gradio interface
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("## 🎨 AI-Powered Image Colorization")
    gr.Markdown("Colorize black-and-white images using a segmentation-assisted EfficientUNet model.")
    input_image = gr.Image(label="🖼️ Upload Grayscale Image", type="pil")
    with gr.Row():
        mode = gr.Radio(["Basic", "Advanced"], value="Basic", label="🧭 Mode")
        file_format = gr.Radio(["PNG", "JPG", "WEBP", "TIFF"], value="PNG", label="🗂️ Output Format")
    run_button = gr.Button("🚀 Colorize")
    output_gallery = gr.Gallery(label="🎬 Before and After", columns=2, height=300)
    download_button = gr.File(label="⬇ Download Colorized Image")
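    # Surface failures in the UI as gr.Error instead of failing silently.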
    def process_wrapper(img, mode, fmt):
        try:
            gallery, path = colorize_image(img, mode, fmt)
            return gallery, path
        except Exception as e:
            import traceback
            print("🔥 ERROR:\n", traceback.format_exc())
            raise gr.Error(f"An error occurred:\n{str(e)}")

    run_button.click(fn=process_wrapper,
                     inputs=[input_image, mode, file_format],
                     outputs=[output_gallery, download_button])
demo.launch()