File size: 7,545 Bytes
35e5468
 
 
 
 
9eae6e7
 
 
 
 
 
bd11a0f
ba9718c
 
9eae6e7
 
 
 
 
ba9718c
 
 
 
 
 
 
 
 
 
 
9eae6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba9718c
 
 
 
9e43b47
a5c1cde
ba9718c
 
 
a5c1cde
ba9718c
 
 
 
 
 
 
 
 
9eae6e7
ba9718c
 
 
 
 
 
9eae6e7
 
 
 
 
 
 
 
 
 
 
4530503
9eae6e7
4530503
ba9718c
9eae6e7
ba9718c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659310d
ba9718c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import subprocess

# Install the demo's runtime dependencies before the heavy imports below.
# An argv list avoids spawning an extra shell just to run another shell.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so setup.sh re-runs each time — consider guarding it.
_setup_result = subprocess.run(["sh", "setup.sh"])
if _setup_result.returncode == 0:
    print("Installed the dependencies!")
else:
    # Don't claim success when the install script failed.
    print(f"warn: setup.sh exited with status {_setup_result.returncode}")

from typing import Tuple
import dnnlib
from PIL import Image
import numpy as np
import torch
import legacy
import cv2
from streamlit_drawable_canvas import st_canvas
import streamlit as st

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class label for class-conditional generators; None is the value used for
# unconditional networks (checked in fcf_inpaint).
class_idx = None
# Truncation strength forwarded to the generator call in fcf_inpaint.
truncation_psi = 0.1

# Page heading rendered by image_inpainting via st.title.
title = "FcF-Inpainting"

# Usage instructions rendered as raw HTML (unsafe_allow_html=True) by image_inpainting.
# NOTE(review): "font-weight: w300" is not a valid CSS value, so browsers ignore
# that declaration; "300" was probably intended.
description = "<p style='color:royalblue; font-size: 14px; font-weight: w300;'>  \
                [Note: The image and mask are resized to 512x512 before inpainting. The <span style='color:#E0B941;'>Run FcF-Inpainting</span> button will automatically appear after you draw a mask.] To use FcF-Inpainting: <br> \
                (1) <span style='color:#E0B941;'>Upload an Image</span> or <span style='color:#E0B941;'> select a sample image on the left</span>. <br>  \
                (2) Adjust the brush stroke width and <span style='color:#E0B941;'>draw the mask on the image</span>. You may also change the drawing tool on the sidebar. <br>\
                (3) After drawing a mask, click the <span style='color:#E0B941;'>Run FcF-Inpainting</span> and witness the MAGIC! 🪄 ✨ ✨<br> \
                (4) You may <span style='color:#E0B941;'>download/undo/redo/delete</span> the changes on the image using the options below the image box.</p>"

# Links bar (project page / paper / GitHub) rendered as raw HTML under the title.
# NOTE(review): "font-weight: w500" is likewise not a valid CSS value.
article = "<p style='color: #E0B941; font-size: 16px; font-weight: w500; text-align: center'> <a style='color: #E0B941;' href='https://praeclarumjj3.github.io/fcf-inpainting/' target='_blank'>Project Page</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github</a></p>"

def create_model(network_pkl):
    """Load the EMA generator (the ``'G_ema'`` entry) from a network pickle.

    Args:
        network_pkl: Location handed to ``dnnlib.util.open_url``.

    Returns:
        The generator, switched to eval mode and moved to ``device``.
    """
    print(f'Loading networks from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as stream:
        generator = legacy.load_network_pkl(stream)['G_ema']  # type: ignore

    generator = generator.eval().to(device)
    # Report model size in millions of parameters.
    param_count = sum(param.numel() for param in generator.parameters())
    print(f"Generator Params: {param_count / 1e6} M")
    return generator

def fcf_inpaint(G, org_img, erased_img, mask):
    """Predict the masked region with generator ``G`` and composite it back.

    Args:
        G: Generator exposing ``c_dim`` and callable with ``img``, ``c``,
            ``truncation_psi`` and ``noise_mode`` keyword arguments.
        org_img: Original image tensor; kept wherever ``mask`` is 0.
        erased_img: Image tensor with the hole region blanked out.
        mask: Float tensor, 1 inside the hole and 0 elsewhere.

    Returns:
        ``mask * prediction + (1 - mask) * org_img`` on ``device``.

    Raises:
        ValueError: if ``G`` is class-conditional (``c_dim != 0``) and the
            module-level ``class_idx`` is None.
    """
    # Put every input on the model's device before they are combined.
    mask = mask.to(device)
    erased_img = erased_img.to(device)
    org_img = org_img.to(device)

    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            # Fail loudly: label[:, None] = 1 would silently set every class.
            raise ValueError("class_idx can't be None.")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print ('warn: --class=lbl ignored when running on an unconditional network')

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label, truncation_psi=truncation_psi, noise_mode='const')
    return mask * pred_img + (1 - mask) * org_img


def denorm(img):
    """Turn the first image of a batched CHW tensor into an HWC uint8 array.

    Values are mapped from [-1, 1] to [0, 255], rounded to the nearest
    integer and clipped to the uint8 range.
    """
    first = np.asarray(img[0].cpu(), dtype=np.float32)
    hwc = np.transpose(first, (1, 2, 0))
    scaled = (hwc + 1) * 127.5
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

def pil_to_numpy(pil_img: "Image.Image") -> torch.Tensor:
    """Convert an image into a normalized NCHW float tensor.

    Despite the name, this returns a single ``torch.Tensor``, not a NumPy
    array (the previous ``Tuple[...]`` annotation was wrong).

    Args:
        pil_img: Anything ``np.array`` accepts, typically a PIL image.
            2-D (grayscale) input is treated as single-channel.

    Returns:
        Float tensor of shape (1, C, H, W) with 0..255 mapped to -1..1.
    """
    tensor = torch.from_numpy(np.array(pil_img))
    if tensor.ndim == 2:
        # Give grayscale images an explicit channel axis so permute works.
        tensor = tensor.unsqueeze(-1)
    return tensor[None].permute(0, 3, 1, 2).float() / 127.5 - 1

def _prepare_tensors(input_img, mask):
    """Build the (image, erased image, hole mask) tensors shared by process_mask and inpaint.

    Args:
        input_img: RGBA image array (converted with cv2.COLOR_RGBA2RGB).
        mask: RGBA array whose alpha channel marks the hole: every pixel
            where ``255 - alpha`` is positive belongs to the hole.

    Returns:
        ``(rgb, rgb_erased, mask_tensor)`` on ``device``. ``rgb`` has shape
        (1, 3, H, W) with 0..255 mapped to -1..1. ``rgb_erased`` is ``rgb``
        zeroed inside the hole. ``mask_tensor`` has shape (1, 1, H, W) and
        is 1.0 inside the hole and 0.0 elsewhere.
    """
    rgb = cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB)
    hole = (255 - mask[:, :, 3]) > 0
    mask_tensor = torch.from_numpy(hole.astype(np.float32))[None, None].to(device)

    rgb = torch.from_numpy(rgb.transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)
    rgb = (rgb / 127.5 - 1).to(device)
    rgb_erased = rgb * (1 - mask_tensor)  # erase the hole region
    return rgb, rgb_erased, mask_tensor


def process_mask(input_img, mask):
    """Return the input image with the hole blanked, as an HWC uint8 array for display."""
    _, rgb_erased, _ = _prepare_tensors(input_img, mask)
    return denorm(rgb_erased)


def inpaint(input_img, mask, model):
    """Inpaint the hole described by ``mask`` and return the result as HWC uint8.

    Args:
        input_img: RGBA image array.
        mask: RGBA array; pixels with alpha below 255 are filled in.
        model: Generator passed to ``fcf_inpaint``.
    """
    rgb, rgb_erased, mask_tensor = _prepare_tensors(input_img, mask)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        comp_img = fcf_inpaint(G=model, org_img=rgb, erased_img=rgb_erased, mask=mask_tensor)
    return denorm(comp_img)

def run_app(model):
    """Seed per-session state, render the inpainting page, then add a sidebar divider.

    Args:
        model: Generator forwarded unchanged to ``image_inpainting``.
    """
    session_defaults = {"button_id": "", "color_to_label": {}}
    for key, value in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    image_inpainting(model)

    st.sidebar.markdown("---")

def _canvas_to_mask(image_data):
    """Convert the canvas drawing layer into the RGBA mask used by process_mask/inpaint.

    Pixels whose RGB is exactly (0, 0, 0) become opaque black (kept).
    Pixels whose RGB is exactly the magenta stroke colour (255, 0, 255)
    become fully transparent (hole). All other pixels keep their RGBA value
    unchanged. Works on a copy; ``image_data`` is not modified.
    """
    mask = image_data.copy()
    red, green, blue = mask[:, :, 0], mask[:, :, 1], mask[:, :, 2]
    # Compute both selections before writing, so the writes can't affect them.
    background = (red == 0) & (green == 0) & (blue == 0)
    stroke = (red == 255) & (green == 0) & (blue == 255)
    mask[background] = [0, 0, 0, 255]
    mask[stroke] = [0, 0, 0, 0]  # RGBA
    return mask


def image_inpainting(model):
    """Render the inpainting page: sidebar inputs, drawing canvas and results.

    Args:
        model: Generator forwarded to ``inpaint``.
    """
    if 'reuse_image' not in st.session_state:
        st.session_state.reuse_image = None

    st.title(title)
    st.markdown(article, unsafe_allow_html=True)
    st.markdown(description, unsafe_allow_html=True)

    uploaded = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

    sample_image = st.sidebar.radio('Choose a Sample Image', [
        'scene-background.png',
        'fence-background.png',
        'bench.png',
        'house.png',
        'landscape.png',
        'truck.png',
        'scenery.png',
        'grass-texture.png',
        'mapview-texture.png',
    ])

    drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line"))

    # An uploaded file takes precedence over the bundled sample images.
    source = uploaded if uploaded else f"./test_512/{sample_image}"
    image = Image.open(source).convert("RGBA").resize((512, 512))
    width, height = image.size
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 100, 20)

    canvas_result = st_canvas(
        stroke_color="rgba(255, 0, 255, 0.8)",
        stroke_width=stroke_width,
        background_image=image,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        key="canvas",
    )

    # Only offer inpainting once something has been drawn on the canvas.
    objects = (canvas_result.json_data or {}).get("objects", [])
    if canvas_result.image_data is None or not objects:
        return

    mask = _canvas_to_mask(canvas_result.image_data)
    if st.button('Run FcF-Inpainting'):
        image_rgba = np.array(image)
        col1, col2 = st.columns([1, 1])
        with col1:
            st.write("Masked Image")
            st.image(process_mask(image_rgba, mask))
        with col2:
            st.write("Inpainted Image")
            st.image(inpaint(image_rgba, mask, model))

def _load_model_cached(network_pkl):
    """Load the generator once per server process when Streamlit supports it.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``create_model`` call would reload the pickle on each click.
    ``st.cache_resource`` (or the older ``st.experimental_singleton``) keeps a
    single shared instance. On Streamlit versions with neither, this falls
    back to the uncached load.
    """
    for attr in ("cache_resource", "experimental_singleton"):
        cache = getattr(st, attr, None)
        if cache is not None:
            return cache(create_model)(network_pkl)
    return create_model(network_pkl)


if __name__ == "__main__":
    st.set_page_config(page_title="FcF-Inpainting", page_icon=":sparkles:")
    st.sidebar.subheader("Configuration")
    model = _load_model_cached("models/places_512.pkl")
    run_app(model)