zhiweili committed
Commit b312a31
1 Parent(s): fe60d00

test app base

Files changed (4):
  1. app.py +1 -1
  2. app_base.py +118 -0
  3. app_haircolor.py +2 -2
  4. inversion_run_base.py +219 -0
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 
 # from app_base import create_demo as create_demo_face
-from app_haircolor import create_demo as create_demo_haircolor
+from app_base import create_demo as create_demo_haircolor
 
 with gr.Blocks(css="style.css") as demo:
     with gr.Tabs():
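
Note: the hunk keeps the create_demo_haircolor alias, so the tab wiring below line 7 presumably stays unchanged. A minimal sketch of that wiring, assuming the usual multi-tab Space pattern (the tab label and the launch call are guesses; the hunk cuts off at line 7):

import gradio as gr

from app_base import create_demo as create_demo_haircolor

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab(label="Hair Color"):  # hypothetical label, not shown in the hunk
            create_demo_haircolor()

demo.launch()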
app_base.py ADDED
@@ -0,0 +1,118 @@
+import spaces
+import gradio as gr
+import time
+import torch
+
+from PIL import Image
+from segment_utils import (
+    segment_image,
+    restore_result,
+)
+from enhance_utils import enhance_image
+
+DEFAULT_SRC_PROMPT = "a woman, photo"
+DEFAULT_EDIT_PROMPT = "a beautiful woman, photo, hollywood style face, 8k, high quality"
+
+DEFAULT_CATEGORY = "hair"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def create_demo() -> gr.Blocks:
+    from inversion_run_base import run as base_run
+
+    @spaces.GPU(duration=10)
+    def image_to_image(
+        input_image: Image,
+        input_image_prompt: str,
+        edit_prompt: str,
+        seed: int,
+        w1: float,
+        num_steps: int,
+        start_step: int,
+        guidance_scale: float,
+        strength: float,
+        generate_size: int,
+    ):
+        w2 = 1.0
+        run_task_time = 0
+        time_cost_str = ''
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        run_model = base_run
+        res_image = run_model(
+            input_image,
+            input_image_prompt,
+            edit_prompt,
+            generate_size,
+            seed,
+            w1,
+            w2,
+            num_steps,
+            start_step,
+            guidance_scale,
+            strength,
+        )
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        enhanced_image = enhance_image(res_image, False)
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+        return enhanced_image, res_image, time_cost_str
+
+    def get_time_cost(run_task_time, time_cost_str):
+        now_time = int(time.time()*1000)
+        if run_task_time == 0:
+            time_cost_str = 'start'
+        else:
+            if time_cost_str != '':
+                time_cost_str += '-->'
+            time_cost_str += f'{now_time - run_task_time}'
+        run_task_time = now_time
+        return run_task_time, time_cost_str
+
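+    # layout: prompts and sampler controls in the top row, image panes below;
+    # the Edit Image button chains segment_image -> image_to_image ->
+    # restore_result, each stage gated on the previous one via .success()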
+    with gr.Blocks() as demo:
+        croper = gr.State()
+        with gr.Row():
+            with gr.Column():
+                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
+                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+            with gr.Column():
+                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
+                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
+                strength = gr.Slider(minimum=0, maximum=2, value=0.3, step=0.1, label="Strength")
+                with gr.Accordion("Advanced Options", open=False):
+                    guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
+                    generate_size = gr.Number(label="Generate Size", value=1024)
+                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
+                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+            with gr.Column():
+                seed = gr.Number(label="Seed", value=8)
+                w1 = gr.Number(label="W1", value=2)
+                g_btn = gr.Button("Edit Image")
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column():
+                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+                download_path = gr.File(label="Download the output image", interactive=False)
+            with gr.Column():
+                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                enhanced_image = gr.Image(label="Enhanced Image", type="pil", interactive=False)
+                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+
+        g_btn.click(
+            fn=segment_image,
+            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+            outputs=[origin_area_image, croper],
+        ).success(
+            fn=image_to_image,
+            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, strength, generate_size],
+            outputs=[enhanced_image, generated_image, generated_cost],
+        ).success(
+            fn=restore_result,
+            inputs=[croper, category, enhanced_image],
+            outputs=[restored_image, download_path],
+        )
+
+    return demo
app_haircolor.py CHANGED
@@ -12,8 +12,8 @@ from enhance_utils import enhance_image
 from inversion_run_adapter import run as adapter_run
 
 
-DEFAULT_SRC_PROMPT = "a woman, with hair"
-DEFAULT_EDIT_PROMPT = "a woman, with red hair, 8k, high quality"
+DEFAULT_SRC_PROMPT = "RAW photo"
+DEFAULT_EDIT_PROMPT = "RAW photo, Fujifilm XT3, sharp hair, high resolution hair, hair tones, natural hair, magazine hair, white color hair"
 
 DEFAULT_CATEGORY = "hair"
 
inversion_run_base.py ADDED
@@ -0,0 +1,219 @@
+import torch
+
+from diffusers import (
+    DDPMScheduler,
+    StableDiffusionXLImg2ImgPipeline,
+)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
+from PIL import Image
+from inversion_utils import get_ddpm_inversion_scheduler, create_xts
+from config import get_config, get_num_steps_actual
+from functools import partial
+from compel import Compel, ReturnedEmbeddingsType
+
+class Object(object):
+    pass
+
+args = Object()
+args.images_paths = None
+args.images_folder = None
+args.force_use_cpu = False
+args.folder_name = 'test_measure_time'
+args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
+args.save_intermediate_results = False
+args.batch_size = None
+args.skip_p_to_p = True
+args.only_p_to_p = False
+args.fp16 = False
+args.prompts_file = 'dataset_measure_time/dataset.json'
+args.images_in_prompts_file = None
+args.seed = 986
+args.time_measure_n = 1
+
+
+assert (
+    args.batch_size is None or args.save_intermediate_results is False
+), "save_intermediate_results is not implemented for batch_size > 1"
+
+generator = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+BASE_MODEL = "stabilityai/sdxl-turbo"
+
+
+pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True,
+)
+pipeline = pipeline.to(device)
+
+pipeline.scheduler = DDPMScheduler.from_pretrained(
+    BASE_MODEL,
+    subfolder="scheduler",
+)
+
+config = get_config(args)
+
+compel_proc = Compel(
+    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
+    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+    requires_pooled=[False, True]
+)
+
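+# run() performs a DDPM-inversion style edit (the mechanics live in
+# inversion_utils, not shown here): the input image is encoded and noised
+# to an intermediate timestep, the scheduler is swapped for an inversion
+# scheduler, and the pipeline then denoises a 3-way batch [src, src, tgt];
+# the tgt result is returned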
+def run(
+    input_image: Image,
+    src_prompt: str,
+    tgt_prompt: str,
+    generate_size: int,
+    seed: int,
+    w1: float,
+    w2: float,
+    num_steps: int,
+    start_step: int,
+    guidance_scale: float,
+    strength: float,
+):
+    generator = torch.Generator().manual_seed(seed)
+
+    config.num_steps_inversion = num_steps
+    config.step_start = start_step
+    num_steps_actual = get_num_steps_actual(config)
+
+
+    num_steps_inversion = config.num_steps_inversion
+    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")
+
+    timesteps, num_inference_steps = retrieve_timesteps(
+        pipeline.scheduler, num_steps_inversion, device, None
+    )
+    timesteps, num_inference_steps = pipeline.get_timesteps(
+        num_inference_steps=num_inference_steps,
+        denoising_start=denoising_start,
+        strength=strength,
+        device=device,
+    )
+    timesteps = timesteps.type(torch.int64)
+
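+    # strength/denoising_start may trim the schedule, so re-derive step_start
+    # and size the per-step max_norm_zs list to match the actual number of
+    # timesteps the inversion scheduler will see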
+    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
+    timesteps_len = len(timesteps)
+    config.step_start = start_step + num_steps_actual - timesteps_len
+    num_steps_actual = timesteps_len
+    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
+    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
+    print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")
+    pipeline.__call__ = partial(
+        pipeline.__call__,
+        num_inference_steps=num_steps_inversion,
+        guidance_scale=guidance_scale,
+        generator=generator,
+        denoising_start=denoising_start,
+        strength=strength,
+    )
+
+    x_0_image = input_image
+    x_0 = encode_image(x_0_image, pipeline)
+    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
+    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
+    latents = [x_ts[0]]
+    x_ts_c_hat = [None]
+    config.ws1 = [w1] * num_steps_actual
+    config.ws2 = [w2] * num_steps_actual
+    pipeline.scheduler = get_ddpm_inversion_scheduler(
+        pipeline.scheduler,
+        config.step_function,
+        config,
+        timesteps,
+        config.save_timesteps,
+        latents,
+        x_ts,
+        x_ts_c_hat,
+        args.save_intermediate_results,
+        pipeline,
+        x_0,
+        v1s_images := [],
+        v2s_images := [],
+        deltas_images := [],
+        v1_x0s := [],
+        v2_x0s := [],
+        deltas_x0s := [],
+        "res12",
+        image_name="im_name",
+        time_measure_n=args.time_measure_n,
+    )
+    latent = latents[0].expand(3, -1, -1, -1)
+    prompt = [src_prompt, src_prompt, tgt_prompt]
+    conditioning, pooled = compel_proc(prompt)
+    image = pipeline.__call__(
+        image=latent,
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        eta=1,
+    ).images
+    return image[2]
+
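+# helper: VAE-encode a PIL image into scaled fp16 latents; the VAE is
+# temporarily upcast to fp32 when force_upcast is set, since the fp16
+# SDXL VAE is prone to numerical overflow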
+def encode_image(image, pipe):
+    image = pipe.image_processor.preprocess(image)
+    originDtype = pipe.dtype
+    image = image.to(device=device, dtype=originDtype)
+
+    if pipe.vae.config.force_upcast:
+        image = image.float()
+        pipe.vae.to(dtype=torch.float32)
+
+    if isinstance(generator, list):
+        init_latents = [
+            retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
+            for i in range(1)
+        ]
+        init_latents = torch.cat(init_latents, dim=0)
+    else:
+        init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
+
+    if pipe.vae.config.force_upcast:
+        pipe.vae.to(originDtype)
+
+    init_latents = init_latents.to(originDtype)
+    init_latents = pipe.vae.config.scaling_factor * init_latents
+
+    return init_latents.to(dtype=torch.float16)
+
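+# vendored copy of diffusers' StableDiffusionXLImg2ImgPipeline.get_timesteps;
+# run() above calls pipeline.get_timesteps, so this module-level copy is unused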
+def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
+    # get the original timestep using init_timestep
+    if denoising_start is None:
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep, 0)
+    else:
+        t_start = 0
+
+    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
+
+    # Strength is irrelevant if we directly request a timestep to start at;
+    # that is, strength is determined by the denoising_start instead.
+    if denoising_start is not None:
+        discrete_timestep_cutoff = int(
+            round(
+                pipe.scheduler.config.num_train_timesteps
+                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
+            )
+        )
+
+        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
+            # if the scheduler is a 2nd order scheduler we might have to do +1
+            # because `num_inference_steps` might be even given that every timestep
+            # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+            # mean that we cut the timesteps in the middle of the denoising step
+            # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+            # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+            num_inference_steps = num_inference_steps + 1
+
+        # because t_n+1 >= t_n, we slice the timesteps starting from the end
+        timesteps = timesteps[-num_inference_steps:]
+        return timesteps, num_inference_steps
+
+    return timesteps, num_inference_steps - t_start
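
For reference, a minimal driver for the new inversion_run_base module, assuming the sibling files it imports (inversion_utils.py, config.py, run_configs/noise_shift_guidance_1_5.yaml) are present; the positional argument order follows the run() signature above, and the values mirror the app_base.py UI defaults:

from PIL import Image
from inversion_run_base import run

src = Image.open("face.png").convert("RGB")  # hypothetical input path

edited = run(
    src,
    "a woman, photo",                                                    # src_prompt
    "a beautiful woman, photo, hollywood style face, 8k, high quality",  # tgt_prompt
    1024,  # generate_size
    8,     # seed
    2.0,   # w1
    1.0,   # w2 (app_base.py hardcodes this)
    20,    # num_steps
    15,    # start_step
    0.0,   # guidance_scale
    0.3,   # strength
)
edited.save("edited.png")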