FcF-Inpainting / app.py
praeclarumjj3's picture
:zap: Build App
9eae6e7
raw history blame
No virus
6.22 kB
from typing import Tuple
import dnnlib
from PIL import Image
import numpy as np
import torch
import legacy
import cv2
import paddlehub as hub
# PaddleHub U2Net module, created once at import time; inpaint() calls its
# Segmentation() method to build the mask for the "Automatic" masking option.
u2net = hub.Module(name='U2Net')
# gradio app imports
import gradio as gr
from torchvision.transforms import ToTensor, ToPILImage
# torchvision converters between PIL images and tensors.
image_to_tensor = ToTensor()
tensor_to_image = ToPILImage()

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Generator settings read by fcf_inpaint():
#   class_idx      -- class index set in the label of class-conditional
#                     generators (G.c_dim != 0); must stay None otherwise.
#   truncation_psi -- forwarded unchanged to the generator call.
class_idx = None
truncation_psi = 0.1
def create_model(network_pkl):
    """Load the ``G_ema`` generator from a network pickle, ready for inference.

    The file at ``network_pkl`` is opened with ``dnnlib.util.open_url`` and
    parsed by ``legacy.load_network_pkl``. The generator is switched to eval
    mode and moved to ``device``, and its parameter count (in millions) is
    printed.

    NOTE(review): ``load_network_pkl`` presumably unpickles the file, so only
    pass trusted checkpoints.

    Args:
        network_pkl: path or URL of the network pickle.

    Returns:
        The generator module taken from the pickle's ``'G_ema'`` entry.
    """
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as stream:
        generator = legacy.load_network_pkl(stream)['G_ema']  # type: ignore
    generator = generator.eval().to(device)
    n_params = sum(param.numel() for param in generator.parameters())
    print(f"Generator Params: {n_params / 1e6} M")
    return generator
def fcf_inpaint(G, org_img, erased_img, mask):
    """Run the FcF generator on an erased image and composite the prediction.

    Args:
        G: generator (e.g. from ``create_model``). It must expose ``c_dim``
            and accept the ``img``, ``c``, ``truncation_psi`` and
            ``noise_mode`` keywords.
        org_img: original image tensor.
        erased_img: image tensor with the hole region erased. It is
            concatenated with the mask along dim 1 as the generator input.
        mask: hole mask tensor. Where it is 1 the prediction is used; where
            it is 0 the original pixels are kept.

    Returns:
        Tensor ``mask * pred + (1 - mask) * org_img`` on ``device``. It is
        computed without autograd, so callers can convert it to NumPy directly.

    Raises:
        ValueError: if ``G`` is class-conditional (``G.c_dim != 0``) but the
            module-level ``class_idx`` is None.
    """
    # Move every input to the generator's device up front. Previously only the
    # compositing step did this, after the generator had already consumed
    # mask/erased_img.
    org_img = org_img.to(device)
    erased_img = erased_img.to(device)
    mask = mask.to(device)

    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            # The original only constructed this exception without raising it.
            raise ValueError("class_idx can't be None.")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print('warn: --class=lbl ignored when running on an unconditional network')

    # Pure inference: skip building an autograd graph (saves memory, and lets
    # the result go through np.asarray in denorm()).
    with torch.no_grad():
        pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label,
                     truncation_psi=truncation_psi, noise_mode='const')
    return mask * pred_img + (1 - mask) * org_img
def show_images(img):
    """Wrap a single image array in a PIL ``Image``; nothing is displayed.

    Despite the plural name, exactly one array is converted, via
    ``Image.fromarray``. In this file it receives the (H, W, C) uint8 array
    produced by ``denorm``.
    """
    pil_image = Image.fromarray(img)
    return pil_image
def denorm(img):
    """Convert the first image of a batch from [-1, 1] floats to uint8 HWC.

    Args:
        img: batched tensor. ``img[0]`` is treated as (C, H, W) with values
            expected in [-1, 1].

    Returns:
        np.ndarray of shape (H, W, C) and dtype uint8. Values are mapped with
        ``(x + 1) * 127.5``, rounded to the nearest integer (ties to even) and
        clipped to [0, 255].
    """
    # detach() so tensors that still require grad (e.g. raw generator output)
    # can be converted; np.asarray on such a tensor raises a RuntimeError.
    arr = img[0].detach().cpu().float().numpy().transpose(1, 2, 0)
    arr = (arr + 1) * 127.5
    return np.rint(arr).clip(0, 255).astype(np.uint8)
def pil_to_numpy(pil_img: "Image.Image") -> torch.Tensor:
    """Convert an image to a (1, C, H, W) float tensor scaled to [-1, 1].

    Despite the name, this returns a single ``torch.Tensor``, not a NumPy
    array or a tuple. Any input ``np.array`` accepts works (PIL image or
    ndarray). Pixel values are mapped with ``x / 127.5 - 1``, so uint8
    [0, 255] becomes [-1, 1].

    2-D (grayscale) input gets a singleton channel axis instead of failing
    in ``permute``.
    """
    img = np.array(pil_img)
    if img.ndim == 2:
        img = img[:, :, None]
    # (H, W, C) -> (1, H, W, C) -> (1, C, H, W)
    return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
def _prepare_mask(input_img, mask, option):
    """Build the hole mask for ``input_img`` as a 512x512 float array in [0, 1].

    "Automatic" derives the mask from U2Net. Any other option uses the
    user-drawn ``mask`` image, first resized to the input image's size.
    """
    if option == "Automatic":
        # cv2 expects BGR order. Convert to RGB first so RGBA, palette or
        # grayscale uploads do not break cvtColor (a no-op for RGB input).
        bgr = cv2.cvtColor(np.array(input_img.convert('RGB')), cv2.COLOR_RGB2BGR)
        # NOTE(review): visualization=True presumably makes PaddleHub write
        # result images to output_dir on every request; confirm it is needed.
        result = u2net.Segmentation(
            images=[bgr],
            paths=None,
            batch_size=1,
            input_size=320,
            output_dir='output',
            visualization=True)
        mask = Image.fromarray(result[0]['mask'])
    else:
        if mask is None:
            raise ValueError(
                "No mask was provided: draw a mask on the canvas "
                "or select the Automatic option.")
        mask = mask.resize(input_img.size)
    # Grayscale scaled to [0, 1]. INTER_NEAREST copies mask values instead of
    # interpolating between them.
    mask = np.array(mask.convert('L')) / 255.
    return cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)


def _prepare_image(input_img):
    """Return ``input_img`` as a (1, 3, 512, 512) float32 tensor in [-1, 1] on ``device``."""
    rgb = np.array(input_img.convert('RGB'))
    rgb = cv2.resize(rgb, (512, 512), interpolation=cv2.INTER_AREA)
    # HWC uint8 -> NCHW float32, then map [0, 255] to [-1, 1].
    rgb = torch.from_numpy(rgb.transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)
    return (rgb / 127.5 - 1).to(device)


def inpaint(input_img, mask, option):
    """Gradio callback: erase the masked region of ``input_img``.

    Args:
        input_img: PIL image uploaded by the user.
        mask: PIL image drawn on the canvas. It is only used when ``option``
            is not "Automatic".
        option: "Automatic" builds the mask with U2Net; any other value
            ("Manual") uses ``mask``.

    Returns:
        Two PIL images, both currently the 512x512 input with the masked
        region erased. The FcF generator call is disabled; see the NOTE below.

    Raises:
        ValueError: if a manual mask is required but ``mask`` is None.
    """
    mask = _prepare_mask(input_img, mask, option)
    # (512, 512) -> (1, 1, 512, 512) so it broadcasts over the RGB channels.
    mask_tensor = torch.from_numpy(mask).to(torch.float32)[None, None].to(device)
    rgb = _prepare_image(input_img)
    # Masked pixels are scaled toward 0 (mid-gray after denorm). Where the
    # mask is 1 the pixel is fully erased.
    rgb_erased = rgb * (1 - mask_tensor)
    # NOTE(review): the FcF generator call is disabled, so the "Inpainted Image"
    # output is currently the erased image too. Re-enabling it would look like:
    #   model = create_model("models/places_512.pkl")
    #   comp_img = fcf_inpaint(G=model, org_img=rgb, erased_img=rgb_erased, mask=mask_tensor)
    #   return show_images(denorm(rgb_erased)), show_images(denorm(comp_img))
    # Ideally the model would be loaded once at import time, not per request.
    erased_np = denorm(rgb_erased)
    return show_images(erased_np), show_images(erased_np)
# Gradio input components, in the positional order of
# inpaint(input_img, mask, option).
gradio_inputs = [
    # The photo to inpaint.
    gr.inputs.Image(type='pil', tool=None, label="Input Image"),
    # User-drawn mask, used for the "Manual" option. NOTE(review):
    # invert_colors=True presumably turns the dark strokes drawn on the white
    # canvas into bright (erased) mask pixels; confirm for the Gradio version.
    gr.inputs.Image(type='pil', source="canvas", label="Mask", invert_colors=True),
    # Masking mode, forwarded to inpaint() as `option`.
    gr.inputs.Radio(
        choices=["Automatic", "Manual"],
        type="value",
        default="Manual",
        label="Masking Choice",
    ),
]
# One output component per value returned by inpaint(), in the same order.
gradio_outputs = [
    gr.outputs.Image(label='Image with Hole'),
    gr.outputs.Image(label='Inpainted Image'),
]
# Example rows for the Gradio gallery: (input image, mask image, masking
# option), matching the order of gradio_inputs. In the "Automatic" row the
# mask file is ignored, because inpaint() builds its own mask with U2Net.
# NOTE(review): rows 3 and 4 pair c_org with b_mask and b_org with c_mask,
# unlike every other row; confirm the swap is intentional.
examples = [
    ['test_512/person512.png', 'test_512/mask_auto.png', 'Automatic'],
    ['test_512/a_org.png', 'test_512/a_mask.png', 'Manual'],
    ['test_512/c_org.png', 'test_512/b_mask.png', 'Manual'],
    ['test_512/b_org.png', 'test_512/c_mask.png', 'Manual'],
    *[
        [f'test_512/{key}_org.png', f'test_512/{key}_mask.png', 'Manual']
        for key in 'defghi'
    ],
]
# Text and styling shown on the Gradio page.
title = "FcF-Inpainting"
# Implicit concatenation keeps the " \n " separators between steps exactly as
# the previous backslash-continued literal produced them.
description = (
    "[Note: Queue time may take up to 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: \n "
    "(1) Upload an Image; \n "
    "(2) Draw (Manual) a Mask on the White Canvas or Generate a mask using U2Net by selecting the Automatic option; \n "
    "(3) Click on Submit and witness the MAGIC! 🪄 ✨ ✨"
)
# The paper link previously pointed to arXiv 2112.10741, which is a different
# paper (GLIDE). The FcF paper is arXiv 2208.03382.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2208.03382' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github Repo</a></p>"
css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
# Wire inpaint() to the page: components, examples gallery, and page text.
iface = gr.Interface(
    fn=inpaint,
    inputs=gradio_inputs,
    outputs=gradio_outputs,
    examples=examples,
    examples_per_page=5,
    title=title,
    description=description,
    article=article,
    css=css,
    layout="vertical",
    thumbnail="fcf_gan.png",
    allow_flagging="never",
)
# Serve on all network interfaces with request queueing and a share link.
iface.launch(
    server_name="0.0.0.0",
    share=True,
    enable_queue=True,
)