import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import DiffusionPipeline
from transparent_background import Remover


def resize_with_padding(img, expected_size):
    """Resize `img` to fit within `expected_size`, then pad it to exactly that size."""
    img.thumbnail((expected_size[0], expected_size[1]))
    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding)


# Example images shown in the UI on load
bird_image = Image.open('bird.jpeg').convert('RGB')
bird_controlnet = Image.open('bird-controlnet.webp').convert('RGB')
bird_sd2 = Image.open('bird-sd2.webp').convert('RGB')
bird_mask = Image.open('bird-mask.webp').convert('RGB')

device = 'cuda'

# Load the background/salient-object detection model used to build the inpainting mask
remover = Remover(mode='base')

pipe = DiffusionPipeline.from_pretrained(
    "yahoo-inc/photo-background-generation",
    custom_pipeline="yahoo-inc/photo-background-generation",
).to(device)


def read_content(file_path: str) -> str:
    """Read the contents of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()


def predict(img, prompt="", seed=0):
    img = img.convert("RGB")
    img = resize_with_padding(img, (512, 512))

    # The remover returns a salient-object (foreground) map; invert it so the
    # background is the region to be regenerated.
    mask = remover.process(img, type='map')
    mask = ImageOps.invert(mask)

    with torch.autocast("cuda"):
        # SD2 + ControlNet: the background mask is also used as the ControlNet conditioning image.
        generator = torch.Generator(device='cuda').manual_seed(int(seed))
        output_controlnet = pipe(
            generator=generator,
            prompt=prompt,
            image=img,
            mask_image=mask,
            control_image=mask,
            num_images_per_prompt=1,
            num_inference_steps=20,
            guess_mode=False,
            controlnet_conditioning_scale=1.0,
            guidance_scale=7.5,
        ).images[0]

        # SD2 baseline: same pipeline, prompt, and seed, but with the ControlNet
        # conditioning scale set to 0 so only the inpainting model acts.
        generator = torch.Generator(device='cuda').manual_seed(int(seed))
        output_sd2 = pipe(
            generator=generator,
            prompt=prompt,
            image=img,
            mask_image=mask,
            control_image=mask,
            num_images_per_prompt=1,
            num_inference_steps=20,
            guess_mode=False,
            controlnet_conditioning_scale=0.0,
            guidance_scale=7.5,
        ).images[0]

    torch.cuda.empty_cache()
    return output_controlnet, output_sd2, mask


css = '''
.container {max-width: 1150px; margin: auto; padding-top: 1.5rem}
#image_upload{min-height: 400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background: transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
    from { transform: rotate(0deg); }
    to { transform: rotate(360deg); }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important;
    background-color: #000000; justify-content: center; align-items: center;
    border-radius: 9999px !important; width: 13rem;
}
#share-btn {
    all: initial; color: #ffffff; font-weight: 600; cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important;
    padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * { all: unset; }
#share-btn-container div:nth-child(-n+2){ width: auto !important; min-height: 0px !important; }
#share-btn-container .wrap { display: none !important; }
'''

image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Row(variant='compact', equal_height=True):
            with gr.Column(variant='compact'):
                image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload",
                                 type="pil", label="Upload an image", width=512, height=512)
                with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
                    prompt = gr.Textbox(label='prompt', placeholder='What do you want in the background?',
                                        show_label=True, elem_id="input-text")
                    seed = gr.Number(label="seed", value=13)
                btn = gr.Button("Generate Background!")
            with gr.Column(variant='compact'):
                controlnet_out = gr.Image(value=bird_controlnet, label="SD2+ControlNet (Ours) Output",
                                          elem_id="output-controlnet", width=512, height=512)
        with gr.Row(variant='compact', equal_height=True):
            with gr.Column(variant='compact'):
                mask_out = gr.Image(value=bird_mask, label="Background Mask",
                                    elem_id="output-mask", width=512, height=512)
            with gr.Column(variant='compact'):
                sd2_out = gr.Image(value=bird_sd2, label="SD2 Output",
                                   elem_id="output-sd2", width=512, height=512)

    btn.click(fn=predict, inputs=[image, prompt, seed], outputs=[controlnet_out, sd2_out, mask_out])

image_blocks.launch()
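
# Optional sketch (not part of the demo): one way to call `predict` directly, e.g. for
# batch processing without the UI. `image_blocks.launch()` above blocks, so this would
# live in a separate script; the prompt and output filenames here are illustrative only.
#
#     controlnet_img, sd2_img, bg_mask = predict(Image.open('bird.jpeg'), prompt="a sunny beach", seed=13)
#     controlnet_img.save('bird_controlnet_out.png')
#     sd2_img.save('bird_sd2_out.png')
#     bg_mask.save('bird_mask_out.png')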