Spaces:
Runtime error
Runtime error
File size: 7,545 Bytes
35e5468 9eae6e7 bd11a0f ba9718c 9eae6e7 ba9718c 9eae6e7 ba9718c 9e43b47 a5c1cde ba9718c a5c1cde ba9718c 9eae6e7 ba9718c 9eae6e7 4530503 9eae6e7 4530503 ba9718c 9eae6e7 ba9718c 659310d ba9718c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import subprocess
# Install runtime dependencies at start-up. This has to run before the
# third-party imports that follow, since setup.sh is what provides them.
# An argument list is used instead of a shell string so no extra shell
# parses the command line.
_setup_result = subprocess.run(["sh", "setup.sh"])
if _setup_result.returncode == 0:
    print("Installed the dependencies!")
else:
    # Best effort: keep going (the packages may already be present), but do
    # not claim success when the script reported a failure.
    print(f"warn: setup.sh exited with status {_setup_result.returncode}; "
          "dependencies may be missing")
from typing import Tuple
import dnnlib
from PIL import Image
import numpy as np
import torch
import legacy
import cv2
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class index used to build the one-hot label for a conditional generator
# (see fcf_inpaint); None means no class is selected.
class_idx = None
# Truncation value forwarded to the generator call in fcf_inpaint.
truncation_psi = 0.1
# Page title shown via st.title.
title = "FcF-Inpainting"
# HTML usage instructions rendered with st.markdown(unsafe_allow_html=True).
# NOTE(review): "font-weight: w300" / "w500" below do not look like standard
# CSS font-weight values (300 / 500) -- confirm before changing the markup.
description = "<p style='color:royalblue; font-size: 14px; font-weight: w300;'> \
[Note: The image and mask are resized to 512x512 before inpainting. The <span style='color:#E0B941;'>Run FcF-Inpainting</span> button will automatically appear after you draw a mask.] To use FcF-Inpainting: <br> \
(1) <span style='color:#E0B941;'>Upload an Image</span> or <span style='color:#E0B941;'> select a sample image on the left</span>. <br> \
(2) Adjust the brush stroke width and <span style='color:#E0B941;'>draw the mask on the image</span>. You may also change the drawing tool on the sidebar. <br>\
(3) After drawing a mask, click the <span style='color:#E0B941;'>Run FcF-Inpainting</span> and witness the MAGIC! 🪄 ✨ ✨<br> \
(4) You may <span style='color:#E0B941;'>download/undo/redo/delete</span> the changes on the image using the options below the image box.</p>"
# HTML link bar (project page, paper title, GitHub) rendered under the title.
# NOTE(review): the paper-title link and the "Github" link share the same
# href -- confirm whether the paper link should point elsewhere.
article = "<p style='color: #E0B941; font-size: 16px; font-weight: w500; text-align: center'> <a style='color: #E0B941;' href='https://praeclarumjj3.github.io/fcf-inpainting/' target='_blank'>Project Page</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github</a></p>"
def create_model(network_pkl):
    """Load the 'G_ema' generator from a network pickle and prepare it for inference.

    Args:
        network_pkl: Path or URL accepted by ``dnnlib.util.open_url``.

    Returns:
        The generator, switched to eval mode and moved to the module-level
        ``device``.
    """
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as pkl_file:
        networks = legacy.load_network_pkl(pkl_file)
    generator = networks['G_ema']  # type: ignore
    generator = generator.eval()
    generator = generator.to(device)

    # Report the generator size in millions of parameters.
    param_count = 0
    for param in generator.parameters():
        param_count += param.numel()
    print("Generator Params: {} M".format(param_count/1e6))
    return generator
def fcf_inpaint(G, org_img, erased_img, mask):
    """Run the generator on a masked image and composite the result.

    Args:
        G: Generator exposing ``c_dim`` and callable as
            ``G(img=..., c=..., truncation_psi=..., noise_mode=...)``.
        org_img: Original image tensor (moved to ``device`` for compositing).
        erased_img: Image tensor with the masked region erased; concatenated
            with the mask along dim 1 to form the generator input.
        mask: Mask tensor, 1 inside the region to fill and 0 elsewhere
            (as implied by the compositing formula below).

    Returns:
        Tensor that takes the generator prediction where ``mask`` is 1 and
        the original image where it is 0.

    Raises:
        ValueError: If the generator is class-conditional (``c_dim != 0``)
            but the module-level ``class_idx`` is None.
    """
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            # The exception used to be constructed but never raised, so
            # execution fell through to ``label[:, None] = 1`` and silently
            # switched on every class.
            raise ValueError("class_idx can't be None.")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print ('warn: --class=lbl ignored when running on an unconditional network')

    # The generator receives the mask shifted to 0.5 - mask, stacked with the
    # erased image along the channel dimension.
    generator_input = torch.cat([0.5 - mask, erased_img], dim=1)
    pred_img = G(img=generator_input, c=label, truncation_psi=truncation_psi, noise_mode='const')

    # Keep original pixels outside the mask; use the prediction inside it.
    comp_img = mask.to(device) * pred_img + (1 - mask).to(device) * org_img.to(device)
    return comp_img
def denorm(img):
    """Convert the first image of a batch from [-1, 1] floats to uint8 HWC.

    Args:
        img: Batched tensor whose first element is laid out channel-first
            (it is transposed with ``(1, 2, 0)``).

    Returns:
        ``np.uint8`` array in height x width x channel order, with values
        mapped via ``(x + 1) * 127.5``, rounded and clipped to [0, 255].
    """
    first = img[0].cpu()
    hwc = np.asarray(first, dtype=np.float32).transpose(1, 2, 0)
    scaled = (hwc + 1) * 127.5
    rounded = np.rint(scaled)
    return rounded.clip(0, 255).astype(np.uint8)
def pil_to_numpy(pil_img: "Image.Image") -> torch.Tensor:
    """Convert an image to a normalized, batched, channel-first float tensor.

    Despite the name, the result is a single ``torch.Tensor``. The previous
    annotations named the ``PIL.Image`` module as the parameter type and a
    two-tensor tuple as the return type, neither of which matched the code.

    Args:
        pil_img: Image (or any array-like accepted by ``np.array``) that
            converts to a height x width x channel array.

    Returns:
        Float tensor of shape ``(1, C, H, W)`` with values scaled from
        [0, 255] to [-1, 1] via ``x / 127.5 - 1``.
    """
    img = np.array(pil_img)
    # Add a batch dimension, then move channels in front of height/width.
    batched = torch.from_numpy(img)[None].permute(0, 3, 1, 2)
    return batched.float() / 127.5 - 1
def process_mask(input_img, mask):
    """Build the preview of the input image with the masked region erased.

    Args:
        input_img: RGBA image array (converted with ``cv2.COLOR_RGBA2RGB``).
        mask: Array with at least four channels; only its alpha channel
            (index 3) is read. Pixels whose alpha is below 255 are erased.

    Returns:
        ``np.uint8`` HWC image produced by ``denorm`` in which erased pixels
        hold the value that 0 maps to after de-normalization.
    """
    rgb_array = np.array(cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB))

    # 1 where the alpha channel is not fully opaque, 0 elsewhere.
    hole = ((255 - mask[:, :, 3]) > 0) * 1
    hole_tensor = torch.from_numpy(hole).to(torch.float32)[None, None].to(device)

    # HWC uint8 -> 1xCxHxW float in [-1, 1].
    chw = rgb_array.transpose(2, 0, 1)
    rgb_tensor = torch.from_numpy(chw.astype(np.float32)).unsqueeze(0)
    rgb_tensor = (rgb_tensor / 127.5 - 1).to(device)

    erased = rgb_tensor * (1 - hole_tensor)  # erase rgb
    return denorm(erased)
def inpaint(input_img, mask, model):
    """Inpaint the masked region of an image with the given generator.

    Args:
        input_img: RGBA image array (converted with ``cv2.COLOR_RGBA2RGB``).
        mask: Array with at least four channels; only its alpha channel
            (index 3) is read. Pixels whose alpha is below 255 are filled in.
        model: Generator passed to ``fcf_inpaint`` as ``G``.

    Returns:
        ``np.uint8`` HWC image produced by ``denorm`` from the composited
        result.
    """
    rgb = np.array(cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB))

    # 1 where the alpha channel is not fully opaque, 0 elsewhere.
    hole = ((255 - mask[:, :, 3]) > 0) * 1
    mask_tensor = torch.from_numpy(hole).to(torch.float32)[None, None].to(device)

    # HWC uint8 -> 1xCxHxW float in [-1, 1].
    rgb = torch.from_numpy(rgb.transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)
    rgb = (rgb / 127.5 - 1).to(device)
    rgb_erased = rgb * (1 - mask_tensor)  # erase rgb

    # Inference only: no autograd graph is needed, and a grad-free result can
    # always be converted to NumPy by denorm.
    with torch.no_grad():
        comp_img = fcf_inpaint(G=model, org_img=rgb, erased_img=rgb_erased, mask=mask_tensor)

    # The erased preview used to be de-normalized here as well, but that
    # result was never used; only the composite is returned.
    return denorm(comp_img)
def run_app(model):
    """Initialise session defaults, render the inpainting page, and close the sidebar.

    Args:
        model: Generator forwarded to ``image_inpainting``.
    """
    # Seed session keys only when they are not already present.
    session_defaults = (("button_id", ""), ("color_to_label", {}))
    for key, initial in session_defaults:
        if key not in st.session_state:
            st.session_state[key] = initial

    image_inpainting(model)

    # Horizontal rule at the bottom of the sidebar.
    with st.sidebar:
        st.markdown("---")
def image_inpainting(model):
    """Render the inpainting page: image selection, mask canvas, and results.

    The user uploads an image or picks a bundled sample, draws a mask on a
    512x512 canvas, and presses the run button. The masked preview and the
    inpainted output are then shown side by side.

    Args:
        model: Generator forwarded to ``inpaint``.
    """
    if 'reuse_image' not in st.session_state:
        st.session_state.reuse_image = None

    st.title(title)
    st.markdown(article, unsafe_allow_html=True)
    st.markdown(description, unsafe_allow_html=True)

    # Sidebar widgets are created in a fixed order because that order is the
    # order in which they appear on the page.
    uploaded = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
    sample_names = [
        'scene-background.png',
        'fence-background.png',
        'bench.png',
        'house.png',
        'landscape.png',
        'truck.png',
        'scenery.png',
        'grass-texture.png',
        'mapview-texture.png',
    ]
    sample_image = st.sidebar.radio('Choose a Sample Image', sample_names)
    drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line"))

    # An uploaded file takes precedence over the selected sample.
    if uploaded:
        image = Image.open(uploaded).convert("RGBA")
    else:
        image = Image.open(f"./test_512/{sample_image}").convert("RGBA")
    image = image.resize((512, 512))
    width, height = image.size

    stroke_width = st.sidebar.slider("Stroke width: ", 1, 100, 20)
    canvas_result = st_canvas(
        stroke_color="rgba(255, 0, 255, 0.8)",
        stroke_width=stroke_width,
        background_image=image,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        key="canvas",
    )

    # Nothing more to show until the canvas has data and at least one stroke.
    if canvas_result.image_data is None or not image:
        return
    if len(canvas_result.json_data["objects"]) == 0:
        return

    im = canvas_result.image_data.copy()
    red, green, blue = im[:, :, 0], im[:, :, 1], im[:, :, 2]
    # Both masks are computed from the untouched canvas before any pixel is
    # rewritten.
    is_background = (red == 0) & (green == 0) & (blue == 0)
    is_stroke = (red == 255) & (green == 0) & (blue == 255)
    # Encode the mask in the alpha channel: opaque = keep, transparent = fill.
    im[is_background] = [0, 0, 0, 255]
    im[is_stroke] = [0, 0, 0, 0]  # RGBA

    if not st.button('Run FcF-Inpainting'):
        return

    col1, col2 = st.columns([1, 1])
    with col1:
        st.write("Masked Image")
        mask_show = process_mask(np.array(image), np.array(im))
        st.image(mask_show)
    with col2:
        st.write("Inpainted Image")
        inpainted_img = inpaint(np.array(image), np.array(im), model)
        st.image(inpainted_img)
if __name__ == "__main__":
    # Script entry point: configure the page, load the generator, run the UI.
    # (A stray trailing "|" after the final call has been removed; it made
    # the statement a syntax error.)
    st.set_page_config(
        page_title="FcF-Inpainting", page_icon=":sparkles:"
    )
    st.sidebar.subheader("Configuration")
    # NOTE(review): Streamlit re-executes the script on user interaction, so
    # this presumably reloads the checkpoint on every rerun -- consider
    # Streamlit's resource caching if start-up cost matters.
    model = create_model("models/places_512.pkl")
    run_app(model)