zhiweili committed
Commit: b312a31
Parent(s): fe60d00

test app base

Files changed:
- app.py (+1, -1)
- app_base.py (+118, -0)
- app_haircolor.py (+2, -2)
- inversion_run_base.py (+219, -0)
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 
 # from app_base import create_demo as create_demo_face
-from
+from app_base import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
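Note: the hunk context is truncated, so the wiring that mounts the imported factory inside `gr.Tabs()` is not visible here. As a hedged sketch (not part of this commit), the renamed import can at least be exercised on its own; `create_demo_haircolor` and its `gr.Blocks` return type come from the diff, and `launch()` is the standard Gradio entry point.

# Hypothetical standalone use of the imported factory (not from this commit);
# the real app mounts it inside the gr.Tabs() block shown above.
from app_base import create_demo as create_demo_haircolor

demo = create_demo_haircolor()  # returns the gr.Blocks UI defined in app_base.py
demo.launch()                   # serve it directly for a quick test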
app_base.py ADDED
@@ -0,0 +1,118 @@
+import spaces
+import gradio as gr
+import time
+import torch
+
+from PIL import Image
+from segment_utils import (
+    segment_image,
+    restore_result,
+)
+from enhance_utils import enhance_image
+
+DEFAULT_SRC_PROMPT = "a woman, photo"
+DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
+
+DEFAULT_CATEGORY = "hair"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def create_demo() -> gr.Blocks:
+    from inversion_run_base import run as base_run
+
+    @spaces.GPU(duration=10)
+    def image_to_image(
+        input_image: Image,
+        input_image_prompt: str,
+        edit_prompt: str,
+        seed: int,
+        w1: float,
+        num_steps: int,
+        start_step: int,
+        guidance_scale: float,
+        strength: float,
+        generate_size: int,
+    ):
+        w2 = 1.0
+        run_task_time = 0
+        time_cost_str = ''
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        run_model = base_run
+        res_image = run_model(
+            input_image,
+            input_image_prompt,
+            edit_prompt,
+            generate_size,
+            seed,
+            w1,
+            w2,
+            num_steps,
+            start_step,
+            guidance_scale,
+            strength,
+        )
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        enhanced_image = enhance_image(res_image, False)
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+        return enhanced_image, res_image, time_cost_str
+
+    def get_time_cost(run_task_time, time_cost_str):
+        now_time = int(time.time()*1000)
+        if run_task_time == 0:
+            time_cost_str = 'start'
+        else:
+            if time_cost_str != '':
+                time_cost_str += f'-->'
+            time_cost_str += f'{now_time - run_task_time}'
+        run_task_time = now_time
+        return run_task_time, time_cost_str
+
+    with gr.Blocks() as demo:
+        croper = gr.State()
+        with gr.Row():
+            with gr.Column():
+                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
+                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+            with gr.Column():
+                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
+                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
+                strength = gr.Slider(minimum=0, maximum=2, value=0.3, step=0.1, label="Strength")
+                with gr.Accordion("Advanced Options", open=False):
+                    guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
+                    generate_size = gr.Number(label="Generate Size", value=1024)
+                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
+                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+            with gr.Column():
+                seed = gr.Number(label="Seed", value=8)
+                w1 = gr.Number(label="W1", value=2)
+                g_btn = gr.Button("Edit Image")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column():
+                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+                download_path = gr.File(label="Download the output image", interactive=False)
+            with gr.Column():
+                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                enhanced_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
+                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+
+        g_btn.click(
+            fn=segment_image,
+            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+            outputs=[origin_area_image, croper],
+        ).success(
+            fn=image_to_image,
+            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, strength, generate_size],
+            outputs=[enhanced_image, generated_image, generated_cost],
+        ).success(
+            fn=restore_result,
+            inputs=[croper, category, enhanced_image],
+            outputs=[restored_image, download_path],
+        )
+
+    return demo
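The core of app_base.py is the three-stage event chain at the bottom: segment_image crops the target region, image_to_image edits it, and restore_result pastes the result back; each `.success()` stage only runs if the previous one finished without an error. Below is a minimal, self-contained sketch of the same chaining pattern with placeholder string handlers (not the app's real functions).

# Minimal sketch of Gradio's .click(...).success(...) chaining, mirroring the
# segment -> edit -> restore pipeline above with dummy handlers.
import gradio as gr

def segment(x):  return x + " | segmented"
def edit(x):     return x + " | edited"
def restore(x):  return x + " | restored"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    seg = gr.Textbox(label="Segmented")
    out = gr.Textbox(label="Edited")
    res = gr.Textbox(label="Restored")
    btn = gr.Button("Run")

    btn.click(
        fn=segment, inputs=inp, outputs=seg,
    ).success(
        fn=edit, inputs=seg, outputs=out,
    ).success(
        fn=restore, inputs=out, outputs=res,
    )

if __name__ == "__main__":
    demo.launch()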
app_haircolor.py CHANGED
@@ -12,8 +12,8 @@ from enhance_utils import enhance_image
 from inversion_run_adapter import run as adapter_run
 
 
-DEFAULT_SRC_PROMPT = "
-DEFAULT_EDIT_PROMPT = "
+DEFAULT_SRC_PROMPT = "RAW photo"
+DEFAULT_EDIT_PROMPT = "RAW photo, Fujifilm XT3, sharp hair, high resolution hair, hair tones, natural hair, magazine hair, white color hair"
 
 DEFAULT_CATEGORY = "hair"
 
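The new defaults are plain SDXL prompt strings; app_haircolor.py routes them through inversion_run_adapter.run, which presumably uses the same Compel embedding setup as inversion_run_base.py below. A hedged sketch of how such strings become prompt embeddings, reusing that Compel configuration (a GPU is assumed, as in the modules below):

# Hedged sketch (not from this commit): turning the new default prompts into
# SDXL embeddings with the same Compel setup as inversion_run_base.py below.
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from compel import Compel, ReturnedEmbeddingsType

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)

compel_proc = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

src_prompt = "RAW photo"
edit_prompt = ("RAW photo, Fujifilm XT3, sharp hair, high resolution hair, "
               "hair tones, natural hair, magazine hair, white color hair")
conditioning, pooled = compel_proc([src_prompt, src_prompt, edit_prompt])
print(conditioning.shape, pooled.shape)  # per-token embeddings and pooled embeddings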
inversion_run_base.py ADDED
@@ -0,0 +1,219 @@
+import torch
+
+from diffusers import (
+    DDPMScheduler,
+    StableDiffusionXLImg2ImgPipeline,
+)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
+from PIL import Image
+from inversion_utils import get_ddpm_inversion_scheduler, create_xts
+from config import get_config, get_num_steps_actual
+from functools import partial
+from compel import Compel, ReturnedEmbeddingsType
+
+class Object(object):
+    pass
+
+args = Object()
+args.images_paths = None
+args.images_folder = None
+args.force_use_cpu = False
+args.folder_name = 'test_measure_time'
+args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
+args.save_intermediate_results = False
+args.batch_size = None
+args.skip_p_to_p = True
+args.only_p_to_p = False
+args.fp16 = False
+args.prompts_file = 'dataset_measure_time/dataset.json'
+args.images_in_prompts_file = None
+args.seed = 986
+args.time_measure_n = 1
+
+
+assert (
+    args.batch_size is None or args.save_intermediate_results is False
+), "save_intermediate_results is not implemented for batch_size > 1"
+
+generator = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+BASE_MODEL = "stabilityai/sdxl-turbo"
+
+
+pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+)
+pipeline = pipeline.to(device)
+
+pipeline.scheduler = DDPMScheduler.from_pretrained(
+    BASE_MODEL,
+    subfolder="scheduler",
+)
+
+config = get_config(args)
+
+compel_proc = Compel(
+    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
+    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+    requires_pooled=[False, True]
+)
+
+def run(
+    input_image:Image,
+    src_prompt:str,
+    tgt_prompt:str,
+    generate_size:int,
+    seed:int,
+    w1:float,
+    w2:float,
+    num_steps:int,
+    start_step:int,
+    guidance_scale:float,
+    strength:float,
+):
+    generator = torch.Generator().manual_seed(seed)
+
+    config.num_steps_inversion = num_steps
+    config.step_start = start_step
+    num_steps_actual = get_num_steps_actual(config)
+
+
+    num_steps_inversion = config.num_steps_inversion
+    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")
+
+    timesteps, num_inference_steps = retrieve_timesteps(
+        pipeline.scheduler, num_steps_inversion, device, None
+    )
+    timesteps, num_inference_steps = pipeline.get_timesteps(
+        num_inference_steps=num_inference_steps,
+        denoising_start=denoising_start,
+        strength=strength,
+        device=device,
+    )
+    timesteps = timesteps.type(torch.int64)
+
+    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
+    timesteps_len = len(timesteps)
+    config.step_start = start_step + num_steps_actual - timesteps_len
+    num_steps_actual = timesteps_len
+    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
+    print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")
+    pipeline.__call__ = partial(
+        pipeline.__call__,
+        num_inference_steps=num_steps_inversion,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        denoising_start=denoising_start,
+        strength=strength,
+    )
+
+    x_0_image = input_image
+    x_0 = encode_image(x_0_image, pipeline)
+    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
+    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
+    latents = [x_ts[0]]
+    x_ts_c_hat = [None]
+    config.ws1 = [w1] * num_steps_actual
+    config.ws2 = [w2] * num_steps_actual
+    pipeline.scheduler = get_ddpm_inversion_scheduler(
+        pipeline.scheduler,
+        config.step_function,
+        config,
+        timesteps,
+        config.save_timesteps,
+        latents,
+        x_ts,
+        x_ts_c_hat,
+        args.save_intermediate_results,
+        pipeline,
+        x_0,
+        v1s_images := [],
+        v2s_images := [],
+        deltas_images := [],
+        v1_x0s := [],
+        v2_x0s := [],
+        deltas_x0s := [],
+        "res12",
+        image_name="im_name",
+        time_measure_n=args.time_measure_n,
+    )
+    latent = latents[0].expand(3, -1, -1, -1)
+    prompt = [src_prompt, src_prompt, tgt_prompt]
+    conditioning, pooled = compel_proc(prompt)
+    image = pipeline.__call__(
+        image=latent,
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        eta=1,
+    ).images
+    return image[2]
+
+def encode_image(image, pipe):
+    image = pipe.image_processor.preprocess(image)
+    originDtype = pipe.dtype
+    image = image.to(device=device, dtype=originDtype)
+
+    if pipe.vae.config.force_upcast:
+        image = image.float()
+        pipe.vae.to(dtype=torch.float32)
+
+    if isinstance(generator, list):
+        init_latents = [
+            retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
+            for i in range(1)
+        ]
+        init_latents = torch.cat(init_latents, dim=0)
+    else:
+        init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
+
+    if pipe.vae.config.force_upcast:
+        pipe.vae.to(originDtype)
+
+    init_latents = init_latents.to(originDtype)
+    init_latents = pipe.vae.config.scaling_factor * init_latents
+
+    return init_latents.to(dtype=torch.float16)
+
+def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
+    # get the original timestep using init_timestep
+    if denoising_start is None:
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep, 0)
+    else:
+        t_start = 0
+
+    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
+
+    # Strength is irrelevant if we directly request a timestep to start at;
+    # that is, strength is determined by the denoising_start instead.
+    if denoising_start is not None:
+        discrete_timestep_cutoff = int(
+            round(
+                pipe.scheduler.config.num_train_timesteps
+                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
+            )
+        )
+
+        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
+            # if the scheduler is a 2nd order scheduler we might have to do +1
+            # because `num_inference_steps` might be even given that every timestep
+            # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+            # mean that we cut the timesteps in the middle of the denoising step
+            # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+            # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+            num_inference_steps = num_inference_steps + 1
+
+        # because t_n+1 >= t_n, we slice the timesteps starting from the end
+        timesteps = timesteps[-num_inference_steps:]
+        return timesteps, num_inference_steps
+
+    return timesteps, num_inference_steps - t_start
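For reference, `run` can also be driven outside the Gradio app. The argument order below matches the call in app_base.image_to_image, and the numeric values are the UI defaults from app_base.py; the input and output file names are hypothetical placeholders.

# Hypothetical direct call to inversion_run_base.run, mirroring app_base.py's
# image_to_image; file names are placeholders, numeric values are UI defaults.
from PIL import Image
from inversion_run_base import run as base_run

input_image = Image.open("face_crop.png")   # assumed: the segmented crop the app passes in

result = base_run(
    input_image,
    "a woman, photo",                        # src_prompt (DEFAULT_SRC_PROMPT)
    "a beautiful woman, photo, hollywood style face, 8k, high quality",  # tgt_prompt
    1024,                                    # generate_size
    8,                                       # seed
    2.0,                                     # w1
    1.0,                                     # w2 (app_base.py fixes this to 1.0)
    20,                                      # num_steps
    15,                                      # start_step
    0.0,                                     # guidance_scale
    0.3,                                     # strength
)
result.save("edited_crop.png")               # run() returns a PIL image (the edited crop)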