GrayShine commited on
Commit
e574f5a
·
verified ·
1 Parent(s): 2e5e07d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +373 -0
app.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from gradio.themes.utils import colors, fonts, sizes
4
+ import argparse
5
+ from omegaconf import OmegaConf
6
+ import os
7
+ from models import get_models
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from tca.tca_transform import tca_transform_model
10
+ from diffusers.models import AutoencoderKL
11
+ from models.clip import TextEmbedder
12
+ from datasets import video_transforms
13
+ from torchvision import transforms
14
+ from utils import mask_generation_before
15
+ from backend import auto_inpainting
16
+ from einops import rearrange
17
+ import torchvision
18
+ import sys
19
+ from PIL import Image
20
+ from ip_adapter.ip_adapter_transform import ip_scale_set, ip_transform_model
21
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
22
+ from transformers.image_transforms import convert_to_rgb
23
+ try:
24
+ import utils
25
+
26
+ from diffusion import create_diffusion
27
+ from download import find_model
28
+ except:
29
+ # sys.path.append(os.getcwd())
30
+ sys.path.append(os.path.split(sys.path[0])[0])
31
+ # 代码解释
32
+ # sys.path[0] : 得到C:\Users\maxu\Desktop\blog_test\pakage2
33
+ # os.path.split(sys.path[0]) : 得到['C:\Users\maxu\Desktop\blog_test',pakage2']
34
+ # mmcls 里面跨包引用是因为安装了mmcls
35
+
36
+ import utils
37
+
38
+ from diffusion import create_diffusion
39
+ from download import find_model
40
+
41
+
42
+ def auto_inpainting(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, cfg_scale, img_cfg_scale, negative_prompt=""):
43
+ global use_fp16
44
+ image_prompt_embeds = None
45
+ if prompt is None:
46
+ prompt = ""
47
+ if image is not None:
48
+ clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values
49
+ clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
50
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
51
+ image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
52
+ image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
53
+ model = ip_scale_set(model, img_cfg_scale)
54
+ if use_fp16:
55
+ image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
56
+ b, f, c, h, w = video_input.shape
57
+ latent_h = video_input.shape[-2] // 8
58
+ latent_w = video_input.shape[-1] // 8
59
+
60
+ if use_fp16:
61
+ z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
62
+ masked_video = masked_video.to(dtype=torch.float16)
63
+ mask = mask.to(dtype=torch.float16)
64
+ else:
65
+ z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
66
+
67
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
68
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
69
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
70
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
71
+ masked_video = torch.cat([masked_video] * 2)
72
+ mask = torch.cat([mask] * 2)
73
+ z = torch.cat([z] * 2)
74
+ prompt_all = [prompt] + [negative_prompt]
75
+
76
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
77
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
78
+ class_labels=None,
79
+ cfg_scale=cfg_scale,
80
+ use_fp16=use_fp16,
81
+ ip_hidden_states=image_prompt_embeds)
82
+
83
+ # Sample images:
84
+ samples = diffusion.ddim_sample_loop(
85
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
86
+ mask=mask, x_start=masked_video, use_concat=True
87
+ )
88
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
89
+ if use_fp16:
90
+ samples = samples.to(dtype=torch.float16)
91
+
92
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
93
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
94
+ return video_clip
95
+
96
+
97
+ def auto_inpainting_temp_split(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale, negative_prompt=""):
98
+ global use_fp16
99
+ image_prompt_embeds = None
100
+ if prompt is None:
101
+ prompt = ""
102
+ if image is not None:
103
+ clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values
104
+ clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
105
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
106
+ image_prompt_embeds = torch.cat([clip_image_embeds, clip_image_embeds, uncond_clip_image_embeds], dim=0)
107
+ image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=3).contiguous()
108
+ model = ip_scale_set(model, img_cfg_scale)
109
+ if use_fp16:
110
+ image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
111
+ b, f, c, h, w = video_input.shape
112
+ latent_h = video_input.shape[-2] // 8
113
+ latent_w = video_input.shape[-1] // 8
114
+
115
+ if use_fp16:
116
+ z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
117
+ masked_video = masked_video.to(dtype=torch.float16)
118
+ mask = mask.to(dtype=torch.float16)
119
+ else:
120
+ z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w
121
+
122
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
123
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
124
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
125
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
126
+ masked_video = torch.cat([masked_video] * 3)
127
+ mask = torch.cat([mask] * 3)
128
+ z = torch.cat([z] * 3)
129
+ prompt_all = [prompt] + [prompt] + [negative_prompt]
130
+ prompt_temp = [prompt] + [""] + [""]
131
+
132
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
133
+ temporal_text_prompt = text_encoder(text_prompts=prompt_temp, train=False)
134
+ model_kwargs = dict(encoder_hidden_states=text_prompt,
135
+ class_labels=None,
136
+ scfg_scale=scfg_scale,
137
+ tcfg_scale=tcfg_scale,
138
+ use_fp16=use_fp16,
139
+ ip_hidden_states=image_prompt_embeds,
140
+ encoder_temporal_hidden_states=temporal_text_prompt)
141
+
142
+ # Sample images:
143
+ samples = diffusion.ddim_sample_loop(
144
+ model.forward_with_cfg_temp_split, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \
145
+ mask=mask, x_start=masked_video, use_concat=True
146
+ )
147
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
148
+ if use_fp16:
149
+ samples = samples.to(dtype=torch.float16)
150
+
151
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
152
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
153
+ return video_clip
154
+
155
+
156
+ # ========================================
157
+ # Model Initialization
158
+ # ========================================
159
+ device = None
160
+ output_path = None
161
+ use_fp16 = False
162
+ model = None
163
+ vae = None
164
+ text_encoder = None
165
+ image_encoder = None
166
+ clip_image_processor = None
167
+ def init_model():
168
+ global device
169
+ global output_path
170
+ global use_fp16
171
+ global model
172
+ global diffusion
173
+ global vae
174
+ global text_encoder
175
+ global image_encoder
176
+ global clip_image_processor
177
+ print('Initializing ShowMaker', flush=True)
178
+ parser = argparse.ArgumentParser()
179
+ parser.add_argument("--config", type=str, default="./configs/sample_mask.yaml")
180
+ args = parser.parse_args()
181
+ args = OmegaConf.load(args.config)
182
+ device = "cuda" if torch.cuda.is_available() else "cpu"
183
+ output_path = args.save_img_path
184
+ # Load model:
185
+ latent_h = args.image_size[0] // 8
186
+ latent_w = args.image_size[1] // 8
187
+ args.image_h = args.image_size[0]
188
+ args.image_w = args.image_size[1]
189
+ args.latent_h = latent_h
190
+ args.latent_w = latent_w
191
+ print('loading model')
192
+ model = get_models(True, args).to(device)
193
+ model = tca_transform_model(model).to(device)
194
+ model = ip_transform_model(model).to(device)
195
+ if args.use_compile:
196
+ model = torch.compile(model)
197
+ if args.enable_xformers_memory_efficient_attention:
198
+ if is_xformers_available():
199
+ model.enable_xformers_memory_efficient_attention()
200
+ print("xformer!")
201
+ else:
202
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
203
+ ckpt_path = args.ckpt
204
+ state_dict = find_model(ckpt_path)
205
+ model.load_state_dict(state_dict)
206
+ print('loading succeed')
207
+ model.eval() # important!
208
+ pretrained_model_path = args.pretrained_model_path
209
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
210
+ text_encoder = TextEmbedder(tokenizer_path=pretrained_model_path + "tokenizer",
211
+ encoder_path=pretrained_model_path + "text_encoder").to(device)
212
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
213
+ clip_image_processor = CLIPImageProcessor()
214
+ if args.use_fp16:
215
+ print('Warnning: using half percision for inferencing!')
216
+ vae.to(dtype=torch.float16)
217
+ model.to(dtype=torch.float16)
218
+ text_encoder.to(dtype=torch.float16)
219
+ image_encoder.to(dtype=torch.float16)
220
+ use_fp16 = True
221
+ print('Initialization Finished')
222
+ init_model()
223
+
224
+
225
+ # ========================================
226
+ # Video Generation
227
+ # ========================================
228
+ def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusion):
229
+ with torch.no_grad():
230
+ print("begin generation", flush=True)
231
+ transform_video = transforms.Compose([
232
+ video_transforms.ToTensorVideo(), # TCHW
233
+ video_transforms.WebVideo320512((320, 512)),
234
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
235
+ ])
236
+ video_frames = torch.zeros(16, 3, 320, 512, dtype=torch.uint8)
237
+ video_frames = transform_video(video_frames)
238
+ video_input = video_frames.to(device).unsqueeze(0) # b,f,c,h,w
239
+ mask = mask_generation_before("all", video_input.shape, video_input.dtype, device)
240
+ masked_video = video_input * (mask == 0)
241
+ if image is not None:
242
+ print(image.shape, flush=True)
243
+ # image = Image.open(image)
244
+ if scfg_scale == tcfg_scale:
245
+ video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
246
+ else:
247
+ video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
248
+ video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
249
+ video_path = os.path.join(output_path, 'video.mp4')
250
+ torchvision.io.write_video(video_path, video_clip, fps=8)
251
+ return video_path
252
+
253
+
254
+ # ========================================
255
+ # Video Prediction
256
+ # ========================================
257
+ def video_prediction(text, image, scfg_scale, tcfg_scale, img_cfg_scale, preframe, diffusion):
258
+ with torch.no_grad():
259
+ print("begin generation", flush=True)
260
+ transform_video = transforms.Compose([
261
+ video_transforms.ToTensorVideo(), # TCHW
262
+ # video_transforms.WebVideo320512((320, 512)),
263
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
264
+ ])
265
+ preframe = torch.as_tensor(convert_to_rgb(preframe)).unsqueeze(0)
266
+ zeros = torch.zeros_like(preframe)
267
+ video_frames = torch.cat([preframe] + [zeros] * 15, dim=0).permute(0, 3, 1, 2)
268
+ H_scale = 320 / video_frames.shape[2]
269
+ W_scale = 512 / video_frames.shape[3]
270
+ scale_ = H_scale
271
+ if W_scale < H_scale:
272
+ scale_ = W_scale
273
+ video_frames = torch.nn.functional.interpolate(video_frames, scale_factor=scale_, mode="bilinear", align_corners=False)
274
+ video_frames = transform_video(video_frames)
275
+ video_input = video_frames.to(device).unsqueeze(0) # b,f,c,h,w
276
+ mask = mask_generation_before("first1", video_input.shape, video_input.dtype, device)
277
+ masked_video = video_input * (mask == 0)
278
+ if image is not None:
279
+ print(image.shape, flush=True)
280
+ if scfg_scale == tcfg_scale:
281
+ video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
282
+ else:
283
+ video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
284
+ video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
285
+ video_path = os.path.join(output_path, 'video.mp4')
286
+ torchvision.io.write_video(video_path, video_clip, fps=8)
287
+ return video_path
288
+
289
+
290
+ # ========================================
291
+ # Judge Generation or Prediction
292
+ # ========================================
293
+ def gen_or_pre(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step):
294
+ default_step = [25, 40, 50, 100, 125, 200, 250]
295
+ difference = [abs(item - diffusion_step) for item in default_step]
296
+ diffusion_step = default_step[difference.index(min(difference))]
297
+ diffusion = create_diffusion(str(diffusion_step))
298
+ if preframe_input is None:
299
+ return video_generation(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, diffusion)
300
+ else:
301
+ return video_prediction(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion)
302
+
303
+
304
+ with gr.Blocks() as demo:
305
+ with gr.Row():
306
+ with gr.Column(visible=True) as input_raws:
307
+ with gr.Row():
308
+ with gr.Column(scale=1.0):
309
+ text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
310
+ with gr.Row():
311
+ with gr.Column(scale=0.5):
312
+ image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
313
+ with gr.Column(scale=0.5):
314
+ preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
315
+ with gr.Row():
316
+ with gr.Column(scale=1.0):
317
+ scfg_scale = gr.Slider(
318
+ minimum=1,
319
+ maximum=50,
320
+ value=8,
321
+ step=0.1,
322
+ interactive=True,
323
+ label="Spatial Text Guidence Scale",
324
+ )
325
+ with gr.Row():
326
+ with gr.Column(scale=1.0):
327
+ tcfg_scale = gr.Slider(
328
+ minimum=1,
329
+ maximum=50,
330
+ value=6.5,
331
+ step=0.1,
332
+ interactive=True,
333
+ label="Temporal Text Guidence Scale",
334
+ )
335
+ with gr.Row():
336
+ with gr.Column(scale=1.0):
337
+ img_cfg_scale = gr.Slider(
338
+ minimum=0,
339
+ maximum=1,
340
+ value=0.3,
341
+ step=0.005,
342
+ interactive=True,
343
+ label="Image Guidence Scale",
344
+ )
345
+ with gr.Row():
346
+ with gr.Column(scale=1.0):
347
+ diffusion_step = gr.Slider(
348
+ minimum=20,
349
+ maximum=250,
350
+ value=100,
351
+ step=1,
352
+ interactive=True,
353
+ label="Diffusion Step",
354
+ )
355
+ with gr.Row():
356
+ with gr.Column(scale=0.5, min_width=0):
357
+ run = gr.Button("💭Send")
358
+ with gr.Column(scale=0.5, min_width=0):
359
+ clear = gr.Button("🔄Clear️")
360
+ with gr.Column(scale=0.5, visible=True) as video_upload:
361
+ output_video = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")#.style(height=360)
362
+ # with gr.Column(elem_id="image", scale=0.5) as img_part:
363
+ # with gr.Tab("Video", elem_id='video_tab'):
364
+
365
+ # with gr.Tab("Image", elem_id='image_tab'):
366
+ # up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
367
+ # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
368
+ clear = gr.Button("Restart")
369
+ run.click(gen_or_pre, [text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step], [output_video])
370
+
371
+ # demo.launch(share=True, enable_queue=True)
372
+
373
+ demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)