erfaneshrati committed on
Commit
630e75f
•
1 Parent(s): f31f176

initial commit

Browse files
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Photo Background Generation
3
- emoji: 🏢
4
- colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: Photo Background Generation
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
__pycache__/app.cpython-38.pyc ADDED
Binary file (5.01 kB). View file
 
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from io import BytesIO
4
+ import requests
5
+ import PIL
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ import uuid
10
+ import torch
11
+ from torch import autocast
12
+ import cv2
13
+ from matplotlib import pyplot as plt
14
+ from torchvision import transforms
15
+ from diffusers import DiffusionPipeline
16
+
17
+ from PIL import Image, ImageOps
18
+ import requests
19
+ from io import BytesIO
20
+ from transparent_background import Remover
21
+
22
def resize_with_padding(img, expected_size):
    """Fit ``img`` inside ``expected_size`` and pad it to exactly that size.

    The image is scaled down with ``Image.thumbnail`` (aspect ratio kept) so
    that it fits within ``expected_size``, then a border is added with
    ``ImageOps.expand`` so the content sits centred on the final canvas.

    Args:
        img: PIL image to fit. It is not modified; the work is done on a copy.
        expected_size: ``(width, height)`` of the returned image.

    Returns:
        A new PIL image whose size is ``expected_size``.
    """
    # thumbnail() resizes in place, so operate on a copy to avoid silently
    # shrinking the image object the caller still holds.
    fitted = img.copy()
    fitted.thumbnail((expected_size[0], expected_size[1]))
    delta_width = expected_size[0] - fitted.size[0]
    delta_height = expected_size[1] - fitted.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    # Border order is (left, top, right, bottom); when the leftover is odd the
    # extra pixel goes to the right/bottom side.
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(fitted, padding)
30
+
31
# Example images used as the initial values of the UI image components.
bird_image = Image.open('bird.jpeg').convert('RGB')
bird_controlnet = Image.open('bird-controlnet.webp').convert('RGB')
bird_sd2 = Image.open('bird-sd2.webp').convert('RGB')
bird_mask = Image.open('bird-mask.webp').convert('RGB')

device = 'cuda'
# Load the salient-object detection model exactly once. The previous code
# built Remover() and then immediately replaced it with Remover(mode='base'),
# loading the model twice and discarding the first instance.
# NOTE(review): per the transparent-background docs 'base' is also the default
# mode, so dropping the bare Remover() call should not change the model used.
remover = Remover(mode='base')

# Background-generation pipeline, moved to the GPU.
pipe = DiffusionPipeline.from_pretrained("yahoo-inc/photo-background-generation", custom_pipeline="yahoo-inc/photo-background-generation").to(device)
42
+
43
def read_content(file_path: str) -> str:
    """Return the entire text of the file at ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
50
+
51
def predict(img, prompt="", seed=0):
    """Generate a new background for ``img`` with and without ControlNet.

    The image is converted to RGB and fitted onto a 512x512 canvas, a mask is
    obtained from ``remover`` and inverted, and the pipeline is run twice with
    the same seed: once with ``controlnet_conditioning_scale=1.0`` and once
    with ``0.0``.

    Args:
        img: PIL image uploaded by the user.
        prompt: Text describing the desired background.
        seed: Seed for the random generator. Floats are truncated to int and
            ``None`` is treated as 0.

    Returns:
        A tuple ``(output_controlnet, output_sd2, mask)``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")
    # The UI number field may deliver a float (or None when cleared), but
    # torch.Generator.manual_seed only accepts an int.
    seed = 0 if seed is None else int(seed)

    img = img.convert("RGB")
    img = resize_with_padding(img, (512, 512))
    # NOTE(review): type='map' presumably yields the foreground saliency map;
    # inverting it marks the background as the region to regenerate - confirm.
    mask = remover.process(img, type='map')
    mask = ImageOps.invert(mask)

    def run_pipeline(controlnet_conditioning_scale):
        # A fresh generator per call so both outputs start from the same noise.
        generator = torch.Generator(device='cuda').manual_seed(seed)
        return pipe(
            generator=generator,
            prompt=prompt,
            image=img,
            mask_image=mask,
            control_image=mask,
            num_images_per_prompt=1,
            num_inference_steps=20,
            guess_mode=False,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=7.5,
        ).images[0]

    try:
        with torch.autocast("cuda"):
            output_controlnet = run_pipeline(1.0)
            output_sd2 = run_pipeline(0.0)
    finally:
        # Release cached GPU memory even when generation fails.
        torch.cuda.empty_cache()
    return output_controlnet, output_sd2, mask
63
+
64
# Custom CSS handed to gr.Blocks below.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 512px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''

image_blocks = gr.Blocks(css=css)
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been stripped - confirm it against the intended layout.
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        # First row: input image and controls, next to the ControlNet output.
        with gr.Row(variant='compact', equal_height=True):
            with gr.Column(variant='compact'):
                image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
                with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
                    prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
                    # precision=0 makes the component round to, and return, an
                    # int; the seed is passed to torch's manual_seed, which
                    # rejects floats.
                    seed = gr.Number(label="seed", value=13, precision=0)
                    btn = gr.Button("Generate Background!")
            with gr.Column(variant='compact'):
                controlnet_out = gr.Image(value=bird_controlnet, label="SD2+ControlNet (Ours) Output", elem_id="output-controlnet", width=512, height=512)

        # Second row: the mask next to the output generated with ControlNet
        # conditioning disabled.
        with gr.Row(variant='compact', equal_height=True):
            with gr.Column(variant='compact'):
                mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
            with gr.Column(variant='compact'):
                sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
    # The output order must match the tuple returned by predict().
    btn.click(fn=predict, inputs=[image, prompt, seed], outputs=[controlnet_out, sd2_out, mask_out])


image_blocks.launch()
bird-controlnet.webp ADDED
bird-mask.webp ADDED
bird-sd2.webp ADDED
bird.jpeg ADDED
header.html ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
2
+ <div style="
3
+ display: inline-flex;
4
+ gap: 0.8rem;
5
+ font-size: 1.75rem;
6
+ justify-content: center;
7
+ margin-bottom: 10px;
8
+ ">
9
+ <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
10
+ Text-guided Background Generation for Salient Objects 🎨
11
+ </h1>
12
+ </div>
13
+ <div>
14
+ <p style="align-items: center; margin-bottom: 7px;">
15
+ Create a new background for an image with a visible salient object using a text prompt. This space demos the "object expansion" issue that arises when inpainting models are used for background generation, and shows how it can be fixed using the <a href="https://huggingface.co/yahoo-inc/photo-background-generation">photo-background-generation</a> model. We use <a href="https://pypi.org/project/transparent-background/">transparent-background</a> to obtain the foreground mask. The research paper for this work is available on <a href="https://arxiv.org/abs/2404.10157">arXiv</a>.
16
+ </p>
17
+ </div>
18
+ </div>
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch
3
+ torchvision
4
+ diffusers
5
+ transformers
6
+ ftfy
7
+ numpy
8
+ matplotlib
9
+ uuid
10
+ opencv-python
11
+ git+https://github.com/openai/CLIP.git
12
+ transparent-background
13
+ accelerate