Spaces:
Build error
Build error
from PIL import Image | |
import streamlit as st | |
from streamlit_drawable_canvas import st_canvas | |
from torchvision.transforms import ToTensor | |
import torch | |
import numpy as np | |
import cv2 | |
import aotgan.model.aotgan as net | |
def load_model(model_name): | |
model = net.InpaintGenerator.from_pretrained(model_name) | |
return model | |
def postprocess(image): | |
image = torch.clamp(image, -1., 1.) | |
image = (image + 1) / 2.0 * 255.0 | |
image = image.permute(1, 2, 0) | |
image = image.cpu().numpy().astype(np.uint8) | |
return image | |
def infer(img, mask): | |
with torch.no_grad(): | |
img_cv = cv2.resize(np.array(img), (512, 512)) # Fixing everything to 512 x 512 for this demo. | |
img_tensor = (ToTensor()(img_cv) * 2.0 - 1.0).unsqueeze(0) | |
mask_tensor = (ToTensor()(mask.astype(np.uint8))).unsqueeze(0) | |
masked_tensor = (img_tensor * (1 - mask_tensor).float()) + mask_tensor | |
pred_tensor = model(masked_tensor, mask_tensor) | |
comp_tensor = (pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor)) | |
comp_np = postprocess(comp_tensor[0]) | |
return comp_np | |
stroke_width = 8 | |
stroke_color = "#FFF" | |
bg_color = "#000" | |
bg_image = st.sidebar.file_uploader("Image:", type=["png", "jpg", "jpeg"]) | |
sample_bg_image = st.sidebar.radio('Sample Images', [ | |
"man.png", | |
"pexels-ike-louie-natividad-2709388.jpg", | |
"pexels-christina-morillo-1181686.jpg", | |
"pexels-italo-melo-2379005.jpg", | |
"rainbow.jpeg", | |
"kitty.jpg", | |
"kitty_on_chair.jpeg", | |
]) | |
drawing_mode = st.sidebar.selectbox( | |
"Drawing tool:", ("freedraw", "rect", "circle") | |
) | |
model_name = st.sidebar.selectbox( | |
"Select model:", ("NimaBoscarino/aot-gan-celebahq", "NimaBoscarino/aot-gan-places2") | |
) | |
model = load_model(model_name) | |
bg_image = Image.open(bg_image) if bg_image else Image.open(f"./pictures/{sample_bg_image}") | |
st.subheader("Draw on the image to erase features. The inpainted result will be generated and displayed below.") | |
canvas_result = st_canvas( | |
fill_color="rgb(255, 255, 255)", | |
stroke_width=stroke_width, | |
stroke_color=stroke_color, | |
background_color=bg_color, | |
background_image=bg_image, | |
update_streamlit=True, | |
height=512, | |
width=512, | |
drawing_mode=drawing_mode, | |
key="canvas", | |
) | |
if canvas_result.image_data is not None and bg_image and len(canvas_result.json_data["objects"]) > 0: | |
result = infer(bg_image, canvas_result.image_data[:, :, 3]) | |
st.image(result) | |