# Author: NimaBoscarino
# Commit 4a51a01: "Create Streamlit demo"
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from torchvision.transforms import ToTensor
import torch
import numpy as np
import cv2
import aotgan.model.aotgan as net
# Legacy ``st.cache`` hashes a function's return value on every call to
# detect mutation unless ``allow_output_mutation=True``; for a torch model
# that hashing is slow and pointless. Prefer ``st.cache_resource`` (meant for
# shared, unhashable resources such as models) when the installed Streamlit
# provides it, and fall back to the legacy API on older releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model_name):
    """Load a pretrained AOT-GAN inpainting generator, cached per name.

    Args:
        model_name: Identifier passed to ``InpaintGenerator.from_pretrained``.
            The script passes Hub-style ids such as
            ``"NimaBoscarino/aot-gan-celebahq"``.

    Returns:
        The generator, switched to evaluation mode.
    """
    model = net.InpaintGenerator.from_pretrained(model_name)
    # Inference only: disable any train-time layer behaviour. This is harmless
    # if from_pretrained already did it.
    model.eval()
    return model
def postprocess(image):
    """Turn a generator output tensor into a displayable uint8 image.

    Args:
        image: 3-D channels-first tensor (C x H x W) with values nominally
            in [-1, 1]. Anything outside that range is clipped.

    Returns:
        numpy ``uint8`` array of shape H x W x C with values in [0, 255].
        Fractional parts are truncated, not rounded.
    """
    # Map [-1, 1] onto [0, 255].
    pixels = (image.clamp(-1.0, 1.0) + 1) / 2.0 * 255.0
    # Channels-first -> channels-last, then into host memory for numpy.
    channels_last = pixels.permute(1, 2, 0).cpu()
    return channels_last.numpy().astype(np.uint8)
def infer(img, mask, generator=None):
    """Inpaint the masked region of ``img`` with an AOT-GAN generator.

    Everything is fixed to 512 x 512 for this demo. The image is resized to
    that size, and ``mask`` must already be 512 x 512 so it broadcasts
    against the image tensor (the drawing canvas is created at that size).

    Args:
        img: Source picture. A PIL image (anything with ``.convert``) is
            first converted to RGB. Any other input must be accepted by
            ``np.array`` and yield an H x W x 3 array.
        mask: 2-D array with values in 0..255 (cast to uint8). 255 fully
            replaces a pixel with generated content; intermediate values
            blend the generated and original pixels linearly.
        generator: Model to run, called as ``generator(masked, mask)``.
            Defaults to the module-level ``model`` loaded by the script, so
            existing two-argument calls keep working.

    Returns:
        512 x 512 x 3 uint8 numpy array. Pixels inside the mask come from the
        generator output; the rest come from the resized original image.
    """
    if generator is None:
        generator = model
    if hasattr(img, "convert"):
        # PIL images may be RGBA, grayscale ("L") or palette ("P"). np.array
        # would then yield 4 channels, 1 channel or palette indices, none of
        # which the 3-channel generator accepts.
        img = img.convert("RGB")
    with torch.no_grad():
        img_cv = cv2.resize(np.array(img), (512, 512))  # Fixing everything to 512 x 512 for this demo.
        # ToTensor gives C x H x W floats in [0, 1]; rescale to [-1, 1].
        img_tensor = (ToTensor()(img_cv) * 2.0 - 1.0).unsqueeze(0)
        # ToTensor divides uint8 input by 255, so the mask lands in [0, 1].
        mask_tensor = ToTensor()(mask.astype(np.uint8)).unsqueeze(0)
        # Masked pixels are set to 1.0 before being fed to the generator.
        masked_tensor = img_tensor * (1 - mask_tensor).float() + mask_tensor
        pred_tensor = generator(masked_tensor, mask_tensor)
        # Keep generated pixels only inside the mask.
        comp_tensor = pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor)
        return postprocess(comp_tensor[0])
# Brush settings for the drawing canvas: white strokes over a black backdrop.
stroke_width, stroke_color, bg_color = 8, "#FFF", "#000"

# Bundled example pictures, offered when nothing has been uploaded.
_SAMPLE_PICTURES = [
    "man.png",
    "pexels-ike-louie-natividad-2709388.jpg",
    "pexels-christina-morillo-1181686.jpg",
    "pexels-italo-melo-2379005.jpg",
    "rainbow.jpeg",
    "kitty.jpg",
    "kitty_on_chair.jpeg",
]
# Pretrained generator ids selectable in the sidebar (passed to load_model).
_MODEL_CHOICES = ("NimaBoscarino/aot-gan-celebahq", "NimaBoscarino/aot-gan-places2")

# Sidebar widgets are rendered top to bottom in the order they are created.
bg_image = st.sidebar.file_uploader("Image:", type=["png", "jpg", "jpeg"])
sample_bg_image = st.sidebar.radio("Sample Images", _SAMPLE_PICTURES)
drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "rect", "circle"))
model_name = st.sidebar.selectbox("Select model:", _MODEL_CHOICES)

# load_model is cached, so reruns with the same selection reuse the model.
model = load_model(model_name)
# Use the uploaded file when there is one, otherwise the chosen sample picture.
if bg_image:
    bg_image = Image.open(bg_image)
else:
    bg_image = Image.open(f"./pictures/{sample_bg_image}")

st.subheader("Draw on the image to erase features. The inpainted result will be generated and displayed below.")

# 512 x 512 canvas, matching the fixed working size hard-coded in infer().
canvas_options = dict(
    fill_color="rgb(255, 255, 255)",
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    background_image=bg_image,
    update_streamlit=True,
    height=512,
    width=512,
    drawing_mode=drawing_mode,
    key="canvas",
)
canvas_result = st_canvas(**canvas_options)

# Only run the model once at least one object has been drawn.
has_drawing = (
    canvas_result.image_data is not None
    and bg_image
    and len(canvas_result.json_data["objects"]) > 0
)
if has_drawing:
    # Channel 3 (alpha) of the canvas drawing layer serves as the mask.
    drawn_mask = canvas_result.image_data[:, :, 3]
    result = infer(bg_image, drawn_mask)
    st.image(result)