fffiloni committed on
Commit
454eedf
1 Parent(s): c6d50de

Update app.py

Files changed (1)
  1. app.py +133 -0
app.py CHANGED
@@ -0,0 +1,133 @@
+ import os
+ import numpy as np
+ import argparse
+ import imageio
+ import torch
+
+ from einops import rearrange
+ from diffusers import DDIMScheduler, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTokenizer
+ # from annotator.canny import CannyDetector
+ # from annotator.openpose import OpenposeDetector
+ # from annotator.midas import MidasDetector
+ # import sys
+ # sys.path.insert(0, ".")
+ from huggingface_hub import hf_hub_download
+ import controlnet_aux
+ from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
+ from controlnet_aux.open_pose.body import Body
+
+ from models.pipeline_controlvideo import ControlVideoPipeline
+ from models.util import save_videos_grid, read_video, get_annotation
+ from models.unet import UNet3DConditionModel
+ from models.controlnet import ControlNetModel3D
+ from models.RIFE.IFNet_HDv3 import IFNet
+
+
+ device = "cuda"
+ sd_path = "checkpoints/stable-diffusion-v1-5"
+ inter_path = "checkpoints/flownet.pkl"
+ controlnet_dict = {
+     "pose": "checkpoints/sd-controlnet-openpose",
+     "depth": "checkpoints/sd-controlnet-depth",
+     "canny": "checkpoints/sd-controlnet-canny",
+ }
+
+ controlnet_parser_dict = {
+     "pose": OpenposeDetector,
+     "depth": MidasDetector,
+     "canny": CannyDetector,
+ }
+
+ POS_PROMPT = " ,best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
+ NEG_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
+
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--prompt", type=str, required=True, help="Text description of target video")
+     parser.add_argument("--video_path", type=str, required=True, help="Path to a source video")
+     parser.add_argument("--output_path", type=str, default="./outputs", help="Directory of output")
+     parser.add_argument("--condition", type=str, default="depth", help="Condition of structure sequence")
+     parser.add_argument("--video_length", type=int, default=15, help="Length of synthesized video")
+     parser.add_argument("--height", type=int, default=512, help="Height of synthesized video, and should be a multiple of 32")
+     parser.add_argument("--width", type=int, default=512, help="Width of synthesized video, and should be a multiple of 32")
+     parser.add_argument("--smoother_steps", nargs='+', default=[19, 20], type=int, help="Timesteps at which using interleaved-frame smoother")
+     parser.add_argument("--is_long_video", action='store_true', help="Whether to use hierarchical sampler to produce long video")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed of generator")
+
+     args = parser.parse_args()
+     return args
+
+ if __name__ == "__main__":
+     args = get_args()
+     os.makedirs(args.output_path, exist_ok=True)
+
+     # Height and width should be a multiple of 32
+     args.height = (args.height // 32) * 32
+     args.width = (args.width // 32) * 32
+
+     if args.condition == "pose":
+         pretrained_model_or_path = "lllyasviel/ControlNet"
+         body_model_path = hf_hub_download(pretrained_model_or_path, "annotator/ckpts/body_pose_model.pth", cache_dir="checkpoints")
+         body_estimation = Body(body_model_path)
+         annotator = controlnet_parser_dict[args.condition](body_estimation)
+     else:
+         annotator = controlnet_parser_dict[args.condition]()
+
+     tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
+     vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=torch.float16)
+     unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=torch.float16)
+     controlnet = ControlNetModel3D.from_pretrained_2d(controlnet_dict[args.condition]).to(dtype=torch.float16)
+     interpolater = IFNet(ckpt_path=inter_path).to(dtype=torch.float16)
+     scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
+
+     pipe = ControlVideoPipeline(
+         vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+         controlnet=controlnet, interpolater=interpolater, scheduler=scheduler,
+     )
+     pipe.enable_vae_slicing()
+     pipe.enable_xformers_memory_efficient_attention()
+     pipe.to(device)
+
+     generator = torch.Generator(device="cuda")
+     generator.manual_seed(args.seed)
+
+     # Step 1. Read a video
+     video = read_video(video_path=args.video_path, video_length=args.video_length, width=args.width, height=args.height)
+
+     # Save source video
+     original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
+     save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True)
+
+
+     # Step 2. Parse a video to conditional frames
+     pil_annotation = get_annotation(video, annotator)
+     if args.condition == "depth" and controlnet_aux.__version__ == '0.0.1':
+         pil_annotation = [pil_annot[0] for pil_annot in pil_annotation]
+
+     # Save condition video
+     video_cond = [np.array(p).astype(np.uint8) for p in pil_annotation]
+     imageio.mimsave(os.path.join(args.output_path, f"{args.condition}_condition.mp4"), video_cond, fps=8)
+
+     # Reduce memory (optional)
+     del annotator; torch.cuda.empty_cache()
+
+     # Step 3. Inference
+
+     if args.is_long_video:
+         window_size = int(np.sqrt(args.video_length))
+         sample = pipe.generate_long_video(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
+                     num_inference_steps=50, smooth_steps=args.smoother_steps, window_size=window_size,
+                     generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
+                     width=args.width, height=args.height
+                 ).videos
+     else:
+         sample = pipe(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
+                     num_inference_steps=50, smooth_steps=args.smoother_steps,
+                     generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
+                     width=args.width, height=args.height
+                 ).videos
+     save_videos_grid(sample, f"{args.output_path}/{args.prompt}.mp4")
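For reference, the script is driven entirely by the flags defined in get_args(). A hypothetical invocation might look like the following (the prompt, source path, and condition values are illustrative only, not taken from this commit):

    python app.py --prompt "a person dancing on the beach" --video_path ./source.mp4 --condition depth --video_length 15 --output_path ./outputs

The prompt is padded with POS_PROMPT before sampling, and the run writes source_video.mp4, the rendered {condition}_condition.mp4, and the synthesized {prompt}.mp4 into the chosen output directory.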