Update app.py
Browse files
app.py
CHANGED
@@ -14,36 +14,27 @@ from matplotlib import pyplot as plt
|
|
14 |
from torchvision import transforms
|
15 |
from diffusers import DiffusionPipeline
|
16 |
from diffusers.utils import torch_device
|
|
|
|
|
17 |
pipe = DiffusionPipeline.from_pretrained(
|
18 |
"Fantasy-Studio/Paint-by-Example",
|
19 |
-
torch_dtype=torch.
|
20 |
)
|
21 |
-
pipe = pipe.to("cuda")
|
22 |
-
|
23 |
-
from share_btn import community_icon_html, loading_icon_html, share_js
|
24 |
-
|
25 |
-
def read_content(file_path: str) -> str:
|
26 |
-
"""read the content of target file
|
27 |
-
"""
|
28 |
-
with open(file_path, 'r', encoding='utf-8') as f:
|
29 |
-
content = f.read()
|
30 |
-
|
31 |
-
return content
|
32 |
|
|
|
33 |
def predict(dict, reference, scale, seed, step):
|
34 |
-
width,height=dict["image"].size
|
35 |
-
if width<height:
|
36 |
-
factor=width/512.0
|
37 |
-
width=512
|
38 |
-
height=int((height/factor)/8.0)*8
|
39 |
-
|
40 |
else:
|
41 |
-
factor=height/512.0
|
42 |
-
height=512
|
43 |
-
width=int((width/factor)/8.0)*8
|
44 |
-
init_image = dict["image"].convert("RGB").resize((width,height))
|
45 |
-
mask = dict["mask"].convert("RGB").resize((width,height))
|
46 |
-
generator = torch.Generator(
|
47 |
output = pipe(
|
48 |
image=init_image,
|
49 |
mask_image=mask,
|
@@ -52,9 +43,12 @@ def predict(dict, reference, scale, seed, step):
|
|
52 |
guidance_scale=scale,
|
53 |
num_inference_steps=step,
|
54 |
).images[0]
|
55 |
-
return output, gr.update(visible=True), gr.update(visible=True), gr.update(
|
|
|
|
|
56 |
|
57 |
|
|
|
58 |
css = '''
|
59 |
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
60 |
#image_upload{min-height:400px}
|
@@ -93,15 +87,28 @@ css = '''
|
|
93 |
display: none !important;
|
94 |
}
|
95 |
'''
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
ref_list.sort()
|
101 |
-
image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir)]
|
102 |
image_list.sort()
|
103 |
|
104 |
|
|
|
105 |
image_blocks = gr.Blocks(css=css)
|
106 |
with image_blocks as demo:
|
107 |
gr.HTML(read_content("header.html"))
|
@@ -114,8 +121,8 @@ with image_blocks as demo:
|
|
114 |
|
115 |
with gr.Column():
|
116 |
image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
|
117 |
-
guidance = gr.Slider(label="Guidance scale", value=5, maximum=15,interactive=True)
|
118 |
-
steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1,interactive=True)
|
119 |
|
120 |
seed = gr.Slider(0, 10000, label='Seed (0 = random)', value=0, step=1)
|
121 |
|
@@ -129,19 +136,17 @@ with image_blocks as demo:
|
|
129 |
community_icon = gr.HTML(community_icon_html, visible=True)
|
130 |
loading_icon = gr.HTML(loading_icon_html, visible=True)
|
131 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
|
132 |
-
|
133 |
-
|
134 |
with gr.Row():
|
135 |
with gr.Column():
|
136 |
gr.Examples(image_list, inputs=[image],label="Examples - Source Image",examples_per_page=12)
|
137 |
with gr.Column():
|
138 |
gr.Examples(ref_list, inputs=[reference],label="Examples - Reference Image",examples_per_page=12)
|
139 |
-
|
140 |
btn.click(fn=predict, inputs=[image, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
|
141 |
share_button.click(None, [], [], _js=share_js)
|
142 |
|
143 |
-
|
144 |
-
|
145 |
gr.HTML(
|
146 |
"""
|
147 |
<div class="footer">
|
@@ -154,4 +159,5 @@ with image_blocks as demo:
|
|
154 |
"""
|
155 |
)
|
156 |
|
157 |
-
|
|
|
|
14 |
from torchvision import transforms
|
15 |
from diffusers import DiffusionPipeline
|
16 |
from diffusers.utils import torch_device
|
17 |
+
|
18 |
+
# Load the Paint-by-Example inpainting pipeline from the Hugging Face Hub.
# NOTE(review): this revision keeps the weights in float32 and no longer moves
# the pipeline to CUDA (the old code called pipe.to("cuda")) — presumably a
# deliberate switch to CPU inference; confirm against the deployment target.
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=torch.float32,  # Change to float32 for CPU
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
# Define function to predict
|
25 |
def predict(dict, reference, scale, seed, step):
|
26 |
+
width, height = dict["image"].size
|
27 |
+
if width < height:
|
28 |
+
factor = width / 512.0
|
29 |
+
width = 512
|
30 |
+
height = int((height / factor) / 8.0) * 8
|
|
|
31 |
else:
|
32 |
+
factor = height / 512.0
|
33 |
+
height = 512
|
34 |
+
width = int((width / factor) / 8.0) * 8
|
35 |
+
init_image = dict["image"].convert("RGB").resize((width, height))
|
36 |
+
mask = dict["mask"].convert("RGB").resize((width, height))
|
37 |
+
generator = torch.Generator().manual_seed(seed) if seed != 0 else None
|
38 |
output = pipe(
|
39 |
image=init_image,
|
40 |
mask_image=mask,
|
|
|
43 |
guidance_scale=scale,
|
44 |
num_inference_steps=step,
|
45 |
).images[0]
|
46 |
+
return output, gr.update(visible=True), gr.update(visible=True), gr.update(
|
47 |
+
visible=True
|
48 |
+
)
|
49 |
|
50 |
|
51 |
+
# Define CSS
|
52 |
css = '''
|
53 |
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
|
54 |
#image_upload{min-height:400px}
|
|
|
87 |
display: none !important;
|
88 |
}
|
89 |
'''
|
90 |
+
|
91 |
+
def read_content(file_path: str) -> str:
    """Return the entire contents of *file_path* decoded as UTF-8 text."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return handle.read()
|
99 |
+
|
100 |
+
|
101 |
+
# Collect the bundled example assets so they can be offered in the UI,
# in deterministic (sorted) order.
example = {}
ref_dir = 'examples/reference'
image_dir = 'examples/image'
ref_list = sorted(os.path.join(ref_dir, name) for name in os.listdir(ref_dir))
image_list = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))
|
109 |
|
110 |
|
111 |
+
# Create Gradio Blocks instance
|
112 |
image_blocks = gr.Blocks(css=css)
|
113 |
with image_blocks as demo:
|
114 |
gr.HTML(read_content("header.html"))
|
|
|
121 |
|
122 |
with gr.Column():
|
123 |
image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
|
124 |
+
guidance = gr.Slider(label="Guidance scale", value=5, maximum=15, interactive=True)
|
125 |
+
steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1, interactive=True)
|
126 |
|
127 |
seed = gr.Slider(0, 10000, label='Seed (0 = random)', value=0, step=1)
|
128 |
|
|
|
136 |
community_icon = gr.HTML(community_icon_html, visible=True)
|
137 |
loading_icon = gr.HTML(loading_icon_html, visible=True)
|
138 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
|
139 |
+
|
140 |
+
|
141 |
with gr.Row():
|
142 |
with gr.Column():
|
143 |
gr.Examples(image_list, inputs=[image],label="Examples - Source Image",examples_per_page=12)
|
144 |
with gr.Column():
|
145 |
gr.Examples(ref_list, inputs=[reference],label="Examples - Reference Image",examples_per_page=12)
|
146 |
+
|
147 |
btn.click(fn=predict, inputs=[image, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
|
148 |
share_button.click(None, [], [], _js=share_js)
|
149 |
|
|
|
|
|
150 |
gr.HTML(
|
151 |
"""
|
152 |
<div class="footer">
|
|
|
159 |
"""
|
160 |
)
|
161 |
|
162 |
+
# Launch the Gradio interface defined above (starts the demo server).
image_blocks.launch()
|