|
from PIL import Image |
|
import streamlit as st |
|
from streamlit_drawable_canvas import st_canvas |
|
from torchvision.transforms import ToTensor |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
import aotgan.model.aotgan as net |
|
|
|
# Pick the cache decorator once. The legacy ``st.cache`` hashes the cached
# return value on every hit to detect mutation, which is slow (and can fail)
# for a torch module; ``st.cache_resource`` stores the object as-is. Older
# Streamlit releases lack it, so fall back to the legacy decorator with
# output-mutation checks disabled.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model_name):
    """Load (and cache across Streamlit reruns) a pretrained inpainting generator.

    Args:
        model_name: Identifier forwarded unchanged to
            ``net.InpaintGenerator.from_pretrained``.

    Returns:
        The generator instance returned by ``from_pretrained``. The same
        object is handed back on later calls with the same ``model_name``.
    """
    return net.InpaintGenerator.from_pretrained(model_name)
|
|
|
def postprocess(image):
    """Turn a channel-first tensor in [-1, 1] into a channel-last uint8 array.

    Args:
        image: 3-D tensor whose first axis is moved to the end; values
            outside [-1, 1] are clamped.

    Returns:
        NumPy ``uint8`` array with axes reordered by ``permute(1, 2, 0)`` and
        values mapped linearly from [-1, 1] to [0, 255] (fractions are
        truncated by the integer cast).
    """
    bounded = torch.clamp(image, -1., 1.)
    # Shift [-1, 1] to [0, 1], then stretch to the 8-bit range.
    scaled = (bounded + 1) / 2.0 * 255.0
    channel_last = scaled.permute(1, 2, 0)
    return channel_last.cpu().numpy().astype(np.uint8)
|
|
|
def infer(img, mask):
    """Inpaint the masked region of ``img`` using the module-level ``model``.

    Args:
        img: Picture to edit. A PIL image is converted to RGB first; anything
            else is passed to ``np.array`` as-is.
        mask: 2-D array marking the pixels to fill. It is cast to ``uint8``
            and scaled to [0, 1] by ``ToTensor``, so 255 means "fully
            replace" and 0 means "keep". The caller passes the canvas alpha
            channel.

    Returns:
        The composited result as produced by ``postprocess``: a 512x512
        channel-last ``uint8`` array.
    """
    side = 512

    with torch.no_grad():
        # Uploaded PNGs may be RGBA, palette or grayscale. Without this the
        # image tensor would not have three channels and the generator (which
        # is fed image + mask) would reject it.
        if isinstance(img, Image.Image):
            img = img.convert("RGB")

        img_cv = cv2.resize(np.array(img), (side, side))
        # ToTensor yields [0, 1]; the model works in [-1, 1].
        img_tensor = (ToTensor()(img_cv) * 2.0 - 1.0).unsqueeze(0)

        mask_u8 = mask.astype(np.uint8)
        # The mask must line up pixel-for-pixel with the resized image;
        # nearest-neighbour keeps its values unblended if a resize is needed.
        if mask_u8.shape[:2] != (side, side):
            mask_u8 = cv2.resize(
                mask_u8, (side, side), interpolation=cv2.INTER_NEAREST
            )
        mask_tensor = ToTensor()(mask_u8).unsqueeze(0)

        # Blank out the region to fill (set to 1) before feeding the network.
        masked_tensor = (img_tensor * (1 - mask_tensor).float()) + mask_tensor
        pred_tensor = model(masked_tensor, mask_tensor)
        # Keep original pixels outside the mask, predictions inside it.
        comp_tensor = pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor)
        comp_np = postprocess(comp_tensor[0])

    return comp_np
|
|
|
# Canvas appearance: white strokes on a black backdrop.
stroke_width, stroke_color, bg_color = 8, "#FFF", "#000"

# Bundled pictures offered when the user does not upload one; they are read
# from ./pictures/ further down.
SAMPLE_IMAGES = [
    "man.png",
    "pexels-ike-louie-natividad-2709388.jpg",
    "pexels-christina-morillo-1181686.jpg",
    "pexels-italo-melo-2379005.jpg",
    "rainbow.jpeg",
    "kitty.jpg",
    "kitty_on_chair.jpeg",
]
DRAWING_TOOLS = ("freedraw", "rect", "circle")
MODEL_CHOICES = (
    "NimaBoscarino/aot-gan-celebahq",
    "NimaBoscarino/aot-gan-places2",
)

# Sidebar widgets, created in display order: upload, samples, tool, model.
bg_image = st.sidebar.file_uploader("Image:", type=["png", "jpg", "jpeg"])
sample_bg_image = st.sidebar.radio('Sample Images', SAMPLE_IMAGES)
drawing_mode = st.sidebar.selectbox("Drawing tool:", DRAWING_TOOLS)
model_name = st.sidebar.selectbox("Select model:", MODEL_CHOICES)
|
# Generator for the checkpoint picked in the sidebar; infer() reads this global.
model = load_model(model_name)

# An uploaded file wins; otherwise fall back to the selected bundled sample.
image_source = bg_image or f"./pictures/{sample_bg_image}"
bg_image = Image.open(image_source)
|
|
|
st.subheader("Draw on the image to erase features. The inpainted result will be generated and displayed below.")

# Drawing surface laid over the chosen picture. The canvas is fixed at
# 512x512, the same size infer() resizes the picture to.
canvas_options = {
    "fill_color": "rgb(255, 255, 255)",
    "stroke_width": stroke_width,
    "stroke_color": stroke_color,
    "background_color": bg_color,
    "background_image": bg_image,
    "update_streamlit": True,
    "height": 512,
    "width": 512,
    "drawing_mode": drawing_mode,
    "key": "canvas",
}
canvas_result = st_canvas(**canvas_options)
|
|
|
# Only run the model once the canvas has produced a bitmap, a background
# picture is loaded, and the user has drawn at least one object.
if canvas_result.image_data is not None and bg_image:
    drawn_objects = canvas_result.json_data["objects"]
    if len(drawn_objects) > 0:
        # Channel index 3 of the canvas bitmap is passed as the mask.
        alpha_mask = canvas_result.image_data[:, :, 3]
        result = infer(bg_image, alpha_mask)
        st.image(result)
|
|