File size: 2,456 Bytes
4a51a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from torchvision.transforms import ToTensor
import torch
import numpy as np
import cv2
import aotgan.model.aotgan as net

# ``st.cache`` is deprecated in current Streamlit releases; ``st.cache_resource``
# is the replacement intended for unserialisable singletons such as ML models.
# On older Streamlit versions that lack it, fall back to the legacy decorator
# with output-mutation hashing disabled, so the model is not re-hashed on
# every cache hit.
_cache_model = getattr(st, "cache_resource", None)
if _cache_model is None:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model_name):
    """Load the pretrained inpainting generator called ``model_name``.

    The result is cached by Streamlit, so script reruns triggered by widget
    interaction reuse the already-loaded generator for a given name.

    Args:
        model_name: Identifier passed verbatim to
            ``net.InpaintGenerator.from_pretrained``.

    Returns:
        The generator object returned by ``from_pretrained``.
    """
    return net.InpaintGenerator.from_pretrained(model_name)

def postprocess(image):
    """Convert a 3-D tensor with values in [-1, 1] to a ``uint8`` image array.

    Args:
        image: 3-D ``torch.Tensor``. Values outside [-1, 1] are clamped.
            The tensor may live on any device and may require grad.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with values in [0, 255] and
        the input's axis 0 moved to the last position (so a C x H x W input
        yields H x W x C). Fractional values are truncated, not rounded.
    """
    # detach() so the NumPy conversion below also works when the caller is
    # not inside torch.no_grad(); .numpy() refuses tensors that require grad.
    image = torch.clamp(image.detach(), -1.0, 1.0)
    # Affine map [-1, 1] -> [0, 255].
    image = (image + 1) / 2.0 * 255.0
    image = image.permute(1, 2, 0)
    return image.cpu().numpy().astype(np.uint8)

def infer(img, mask):
    """Inpaint the masked area of ``img`` using the module-level ``model``.

    Args:
        img: Source picture. A PIL image of any mode (it is converted to
            RGB first), or anything ``np.array`` accepts.
        mask: 2-D array, non-zero where content should be regenerated.
            It is cast to ``uint8`` and scaled by ``ToTensor``, so 255 maps
            to 1.0. Assumed to be 512 x 512 to match the resized image;
            the mask itself is not resized here.

    Returns:
        The composited result as produced by ``postprocess``: a ``uint8``
        array with channels last.
    """
    # Uploaded files (PNG in particular) may be RGBA, palette or greyscale.
    # np.array() would then give 4 channels, 1 channel or palette indices;
    # the pipeline below presumably expects three colour channels.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")

    with torch.no_grad():
        img_cv = cv2.resize(np.array(img), (512, 512))  # Fixing everything to 512 x 512 for this demo.
        # ToTensor scales uint8 to [0, 1]; shift to [-1, 1] and add a batch axis.
        img_tensor = (ToTensor()(img_cv) * 2.0 - 1.0).unsqueeze(0)
        mask_tensor = (ToTensor()(mask.astype(np.uint8))).unsqueeze(0)
        # Blank out the masked pixels (set them to 1.0) before generation.
        masked_tensor = (img_tensor * (1 - mask_tensor).float()) + mask_tensor
        pred_tensor = model(masked_tensor, mask_tensor)
        # Keep original pixels outside the mask, generated pixels inside it.
        comp_tensor = (pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor))
        comp_np = postprocess(comp_tensor[0])

        return comp_np

# Brush and canvas appearance: white strokes over a black backdrop.
stroke_width, stroke_color, bg_color = 8, "#FFF", "#000"

# Sidebar widgets, created in display order.
sidebar = st.sidebar

_upload_types = ["png", "jpg", "jpeg"]
bg_image = sidebar.file_uploader("Image:", type=_upload_types)

# Pictures offered when nothing is uploaded (loaded from ./pictures later on).
_sample_files = [
    "man.png",
    "pexels-ike-louie-natividad-2709388.jpg",
    "pexels-christina-morillo-1181686.jpg",
    "pexels-italo-melo-2379005.jpg",
    "rainbow.jpeg",
    "kitty.jpg",
    "kitty_on_chair.jpeg",
]
sample_bg_image = sidebar.radio("Sample Images", _sample_files)

_drawing_tools = ("freedraw", "rect", "circle")
drawing_mode = sidebar.selectbox("Drawing tool:", _drawing_tools)

_checkpoints = (
    "NimaBoscarino/aot-gan-celebahq",
    "NimaBoscarino/aot-gan-places2",
)
model_name = sidebar.selectbox("Select model:", _checkpoints)

# ``infer`` reads ``model`` as a module-level global.
model = load_model(model_name)

# Use the uploaded file when there is one, otherwise the selected sample.
if bg_image:
    bg_image = Image.open(bg_image)
else:
    bg_image = Image.open(f"./pictures/{sample_bg_image}")

st.subheader("Draw on the image to erase features. The inpainted result will be generated and displayed below.")

# Canvas is fixed at 512 x 512, the same size ``infer`` resizes the image to.
_canvas_options = {
    "fill_color": "rgb(255, 255, 255)",
    "stroke_width": stroke_width,
    "stroke_color": stroke_color,
    "background_color": bg_color,
    "background_image": bg_image,
    "update_streamlit": True,
    "height": 512,
    "width": 512,
    "drawing_mode": drawing_mode,
    "key": "canvas",
}
canvas_result = st_canvas(**_canvas_options)

# Run inpainting only once the canvas has data and at least one drawn object.
_pixels = canvas_result.image_data
if _pixels is not None and bg_image and len(canvas_result.json_data["objects"]) > 0:
    # Index 3 of the last axis is used as the mask
    # (presumably the alpha channel of an RGBA canvas; confirm).
    _mask = _pixels[:, :, 3]
    result = infer(bg_image, _mask)
    st.image(result)