|
import gradio as gr |
|
|
|
from io import BytesIO |
|
import requests |
|
import PIL |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import uuid |
|
import torch |
|
from torch import autocast |
|
import cv2 |
|
from matplotlib import pyplot as plt |
|
from torchvision import transforms |
|
from diffusers import DiffusionPipeline |
|
|
|
from PIL import Image, ImageOps |
|
import requests |
|
from io import BytesIO |
|
from transparent_background import Remover |
|
|
|
def resize_with_padding(img, expected_size):
    """Letterbox *img* to exactly ``expected_size``.

    The image is first shrunk with ``thumbnail`` so it fits inside the
    target box while keeping its aspect ratio (note: ``thumbnail``
    mutates the PIL image passed in), then symmetric borders are added
    so the returned image is exactly ``expected_size``.
    """
    target_w, target_h = expected_size
    img.thumbnail((target_w, target_h))
    extra_w = target_w - img.size[0]
    extra_h = target_h - img.size[1]
    left = extra_w // 2
    top = extra_h // 2
    # Right/bottom borders absorb any odd pixel left over by the // 2 split.
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(img, border)
|
|
|
# --- Module-level setup (runs once at import time) ---

# Example images pre-loaded as the demo's default input/output placeholders.
bird_image = Image.open('bird.jpeg').convert('RGB')
bird_controlnet = Image.open('bird-controlnet.webp').convert('RGB')
bird_sd2 = Image.open('bird-sd2.webp').convert('RGB')
bird_mask = Image.open('bird-mask.webp').convert('RGB')

device = 'cuda'

# Salient-object remover used to derive the foreground/background mask.
# FIX: the original built a default Remover() and immediately replaced it
# with Remover(mode='base'), loading the segmentation model twice; only
# the 'base' instance is ever used.
remover = Remover(mode='base')

# Diffusion pipeline (SD2 + ControlNet) for background generation.
pipe = DiffusionPipeline.from_pretrained(
    "yahoo-inc/photo-background-generation",
    custom_pipeline="yahoo-inc/photo-background-generation",
).to(device)
|
|
|
def read_content(file_path: str) -> str:
    """Return the full text of *file_path*, decoded as UTF-8."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return handle.read()
|
|
|
def predict(img, prompt="", seed=0):
    """Generate two background variants for *img* plus the mask used.

    Parameters
    ----------
    img : PIL.Image.Image
        Input photo; converted to RGB and letterboxed to 512x512.
    prompt : str
        Text description of the desired background.
    seed : int or float
        RNG seed.  Gradio's ``gr.Number`` delivers a float, so it is
        coerced to ``int`` here — ``torch.Generator.manual_seed`` raises
        ``TypeError`` on a float.

    Returns
    -------
    tuple
        (ControlNet-conditioned output, plain-SD2 output, background mask).
    """
    # FIX: gr.Number passes a float; manual_seed requires an int.
    seed = int(seed)
    img = img.convert("RGB")
    img = resize_with_padding(img, (512, 512))
    # Saliency map of the foreground, inverted so the *background* is the
    # region the inpainting pipeline repaints.
    mask = remover.process(img, type='map')
    mask = ImageOps.invert(mask)
    with torch.autocast("cuda"):
        generator = torch.Generator(device='cuda').manual_seed(seed)
        output_controlnet = pipe(
            generator=generator, prompt=prompt, image=img, mask_image=mask,
            control_image=mask, num_images_per_prompt=1,
            num_inference_steps=20, guess_mode=False,
            controlnet_conditioning_scale=1.0, guidance_scale=7.5,
        ).images[0]
        # Re-seed so both runs start from identical noise; conditioning
        # scale 0.0 disables the ControlNet branch, giving an SD2 baseline.
        generator = torch.Generator(device='cuda').manual_seed(seed)
        output_sd2 = pipe(
            generator=generator, prompt=prompt, image=img, mask_image=mask,
            control_image=mask, num_images_per_prompt=1,
            num_inference_steps=20, guess_mode=False,
            controlnet_conditioning_scale=0.0, guidance_scale=7.5,
        ).images[0]
    torch.cuda.empty_cache()
    return output_controlnet, output_sd2, mask
|
|
|
# Custom CSS injected into the Gradio Blocks app: sizes the upload/output
# image panes to 512px, styles the footer for light/dark themes, and lays
# out the (unused by default) share button.  The string is passed verbatim
# to gr.Blocks(css=...), so its contents must not be altered casually.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
    all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2){
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
'''
|
|
|
# --- Gradio UI definition and launch ---
# Layout: a 2x2 grid.  Top row: upload pane (with prompt + seed + button)
# and the ControlNet output; bottom row: the computed background mask and
# the plain-SD2 output.  All panes are pre-populated with the example
# bird images loaded at module level.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # Page header loaded from an external HTML fragment.
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Row(variant='compact', equal_height=True, ):
            with gr.Column(variant='compact', ):
                # Input image; 'pil' type so predict() receives a PIL image.
                image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
                with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
                    prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
                    # NOTE(review): gr.Number delivers a float to predict();
                    # confirm the handler coerces it before seeding torch.
                    seed = gr.Number(label="seed", value=13)
                btn = gr.Button("Generate Background!")
            with gr.Column(variant='compact', ):
                controlnet_out = gr.Image(value=bird_controlnet, label="SD2+ControlNet (Ours) Output", elem_id="output-controlnet", width=512, height=512)

    with gr.Row(variant='compact', equal_height=True, ):
        with gr.Column(variant='compact', ):
            mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
        with gr.Column(variant='compact', ):
            sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
    # Wire the button to the inference function; outputs fill the three
    # result panes in order (controlnet, sd2, mask).
    btn.click(fn=predict, inputs=[image, prompt, seed], outputs=[controlnet_out, sd2_out, mask_out ])




# Start the Gradio server (blocking call).
image_blocks.launch()
|
|