Update app.py

app.py CHANGED
@@ -61,7 +61,6 @@ ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
 load_state_dict(model, ckpt, strict=True)
 model.freeze()
 
-# Check if CUDA is available, otherwise use CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
@@ -85,27 +84,26 @@ def process(
     vae_encoder_tile_size: int,
     vae_decoder_tile_size: int
 ):
-    print(
-
-
-
-
-
-
-
-    )
+    print(f"control image shape={control_img.size}\n"
+          f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
+          f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
+          f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
+          f"seed={seed}\n"
+          f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
+          f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}")
+
     pl.seed_everything(seed)
 
-    # Resize
+    # Resize input image
     if sr_scale != 1:
         control_img = control_img.resize(
             tuple(math.ceil(x * sr_scale) for x in control_img.size),
             Image.BICUBIC
         )
-
+
     input_size = control_img.size
 
-    # Resize the
+    # Resize the image
    if not tile_diffusion:
        control_img = auto_resize(control_img, 512)
    else:
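One note on the new `print` block: Python joins adjacent string literals at compile time, and the `tile_diffusion` line above has no trailing `\n`, so it runs together with the `tile_vae` line in the log output. A minimal sketch of that behavior:

```python
# Adjacent f-string literals are concatenated at compile time; without
# a trailing "\n" on the first one, both fields land on the same line.
tile_diffusion, tile_vae = True, False
msg = (f"tile_diffusion={tile_diffusion}"
       f"tile_vae={tile_vae}")
print(msg)  # tile_diffusion=Truetile_vae=False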
@@ -129,39 +127,28 @@ def process(
     shape = (1, 4, height // 8, width // 8)
     x_T = torch.randn(shape, device=device, dtype=torch.float32)
 
-    # Modify the get_learned_conditioning method to handle the attention mask issue
-    def modified_get_learned_conditioning(model, prompt):
-        tokens = model.cond_stage_model.tokenizer.encode(prompt)
-        tokens = torch.LongTensor(tokens).to(model.device).unsqueeze(0)
-        encoder_hidden_states = model.cond_stage_model.transformer(input_ids=tokens).last_hidden_state
-        return encoder_hidden_states
-
-    cond = modified_get_learned_conditioning(model, positive_prompt)
-    uncond = modified_get_learned_conditioning(model, negative_prompt)
-
     if not tile_diffusion and not tile_vae:
         samples = sampler.sample_ccsr(
             steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-            positive_prompt=
+            positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
             cfg_scale=cfg_scale,
             color_fix_type="adain" if use_color_fix else "none"
         )
     else:
         if tile_vae:
-
-            pass
+            model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
         if tile_diffusion:
             samples = sampler.sample_with_tile_ccsr(
                 tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
                 steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-                positive_prompt=
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                 cfg_scale=cfg_scale,
                 color_fix_type="adain" if use_color_fix else "none"
             )
         else:
             samples = sampler.sample_ccsr(
                 steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
-                positive_prompt=
+                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                 cfg_scale=cfg_scale,
                 color_fix_type="adain" if use_color_fix else "none"
             )
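For readers wondering what `cfg_scale` controls now that both prompts go straight to the sampler: the positive and negative conditioning are presumably combined inside `sample_ccsr` via classifier-free guidance. A rough illustrative sketch of the standard mixing rule (not CCSR's actual code):

```python
import torch

def cfg_mix(eps_pos: torch.Tensor, eps_neg: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: extrapolate the positive-prompt
    # prediction away from the negative-prompt prediction.
    # cfg_scale=1.0 reduces to the positive prediction alone.
    return eps_neg + cfg_scale * (eps_pos - eps_neg)
```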
@@ -180,12 +167,31 @@ def update_output_resolution(image, scale):
         return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
     return "Upload an image to see the output resolution"
 
+def update_scale_choices(image):
+    if image is not None:
+        width, height = image.size
+        aspect_ratio = width / height
+        common_resolutions = [
+            (1280, 720), (1920, 1080), (2560, 1440), (3840, 2160),  # 16:9
+            (1440, 1440), (2048, 2048), (2560, 2560), (3840, 3840)  # 1:1
+        ]
+        choices = []
+        for w, h in common_resolutions:
+            if abs(w/h - aspect_ratio) < 0.1:  # Allow some tolerance for aspect ratio
+                scale = max(w/width, h/height)
+                if scale > 1:
+                    choices.append(f"{w}x{h} ({scale:.2f}x)")
+        choices.append("Custom")
+        return gr.update(choices=choices, value=choices[1] if len(choices) > 1 else "Custom")
+    return gr.update(choices=["Custom"], value="Custom")
+
 # Improved UI design
 css = """
 .container {max-width: 1200px; margin: auto; padding: 20px;}
 .input-image {width: 100%; max-height: 500px; object-fit: contain;}
 .output-gallery {display: flex; flex-wrap: wrap; justify-content: center;}
 .output-image {margin: 10px; max-width: 45%; height: auto;}
+.gr-form {border: 1px solid #e0e0e0; border-radius: 8px; padding: 16px; margin-bottom: 16px;}
 """
 
 with gr.Blocks(css=css) as block:
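To see what the new `update_scale_choices` helper actually offers the user, here is a standalone trace of its choice-building loop for a hypothetical 960x540 (16:9) upload; only the 16:9 presets pass the aspect-ratio test:

```python
width, height = 960, 540                 # hypothetical 16:9 input
aspect_ratio = width / height
common_resolutions = [
    (1280, 720), (1920, 1080), (2560, 1440), (3840, 2160),  # 16:9
    (1440, 1440), (2048, 2048), (2560, 2560), (3840, 3840)  # 1:1
]
choices = []
for w, h in common_resolutions:
    if abs(w / h - aspect_ratio) < 0.1:  # same tolerance as above
        scale = max(w / width, h / height)
        if scale > 1:
            choices.append(f"{w}x{h} ({scale:.2f}x)")
choices.append("Custom")
print(choices)
# ['1280x720 (1.33x)', '1920x1080 (2.00x)', '2560x1440 (2.67x)',
#  '3840x2160 (4.00x)', 'Custom']
```

With more than one match the dropdown defaults to `choices[1]`, here the 1920x1080 preset.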
@@ -194,7 +200,20 @@ with gr.Blocks(css=css) as block:
     with gr.Row():
         with gr.Column(scale=1):
             input_image = gr.Image(type="pil", label="Input Image", elem_classes="input-image")
-            sr_scale = gr.
+            sr_scale = gr.Dropdown(
+                label="Output Resolution",
+                choices=["Custom"],
+                value="Custom",
+                interactive=True
+            )
+            custom_scale = gr.Slider(
+                label="Custom Scale",
+                minimum=1,
+                maximum=8,
+                value=4,
+                step=0.1,
+                visible=True
+            )
             output_resolution = gr.Markdown("Upload an image to see the output resolution")
             run_button = gr.Button(value="Run", variant="primary")
 
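The Dropdown/Slider pair added above is the common Gradio preset-plus-"Custom" pattern, completed by the `update_custom_scale` handler in the next hunk. A minimal self-contained sketch of the same pattern (assuming a Gradio version where `gr.update` toggles visibility, as this app uses):

```python
import gradio as gr

with gr.Blocks() as demo:
    preset = gr.Dropdown(label="Output Resolution",
                         choices=["1920x1080 (2.00x)", "Custom"], value="Custom")
    custom = gr.Slider(label="Custom Scale", minimum=1, maximum=8,
                       value=4, step=0.1, visible=True)
    # Reveal the slider only while "Custom" is selected.
    preset.change(lambda c: gr.update(visible=(c == "Custom")),
                  inputs=[preset], outputs=[custom])

demo.launch()
```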
@@ -221,15 +240,43 @@ with gr.Blocks(css=css) as block:
     with gr.Row():
         result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")
 
+    def update_custom_scale(choice):
+        return gr.update(visible=choice == "Custom")
+
+    sr_scale.change(update_custom_scale, inputs=[sr_scale], outputs=[custom_scale])
+
+    def get_scale_value(choice, custom):
+        if choice == "Custom":
+            return custom
+        return float(choice.split()[-1].strip("()x"))
+
     inputs = [
         input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
         cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
         tile_diffusion_stride, tile_vae, vae_encoder_tile_size, vae_decoder_tile_size,
     ]
-    run_button.click(
+    run_button.click(
+        fn=lambda *args: process(*args[:1], args[1], get_scale_value(args[2], args[-1]), *args[3:-1]),
+        inputs=inputs + [custom_scale],
+        outputs=[result_gallery]
+    )
 
-    input_image.change(
-
+    input_image.change(
+        update_scale_choices,
+        inputs=[input_image],
+        outputs=[sr_scale]
+    )
+
+    input_image.change(
+        update_output_resolution,
+        inputs=[input_image, sr_scale],
+        outputs=[output_resolution]
+    )
+    sr_scale.change(
+        update_output_resolution,
+        inputs=[input_image, sr_scale],
+        outputs=[output_resolution]
+    )
 
     input_image.change(
         lambda x: gr.update(interactive=x is not None),
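Two details of this wiring are easy to misread. `get_scale_value` recovers the numeric factor from a preset label by string surgery, and the `run_button.click` lambda splices `custom_scale` (appended as the last input) into the `sr_scale` slot before calling `process`. A quick trace with a hypothetical label:

```python
def get_scale_value(choice, custom):
    if choice == "Custom":
        return custom
    return float(choice.split()[-1].strip("()x"))

# "1920x1080 (2.00x)".split()[-1] -> "(2.00x)"; .strip("()x") -> "2.00"
print(get_scale_value("1920x1080 (2.00x)", 4.0))  # 2.0
print(get_scale_value("Custom", 4.0))             # 4.0

# In the click handler, args = (*inputs, custom_scale), so args[-1] is
# the slider value and args[2] the dropdown label; the lambda converts
# the label to a float before forwarding everything to process().
```

One caveat: `update_output_resolution` is still fed the raw `sr_scale` value, which after this change is a string label rather than a number, so the resolution preview presumably needs the same conversion.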