Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,31 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
-
import numpy as np
|
4 |
-
import argparse
|
5 |
-
import imageio
|
6 |
-
import torch
|
7 |
|
8 |
-
from
|
9 |
-
from diffusers import DDIMScheduler, AutoencoderKL
|
10 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
11 |
-
# from annotator.canny import CannyDetector
|
12 |
-
# from annotator.openpose import OpenposeDetector
|
13 |
-
# from annotator.midas import MidasDetector
|
14 |
-
# import sys
|
15 |
-
# sys.path.insert(0, ".")
|
16 |
-
from huggingface_hub import hf_hub_download, snapshot_download
|
17 |
-
import controlnet_aux
|
18 |
-
from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
|
19 |
-
from controlnet_aux.open_pose.body import Body
|
20 |
-
|
21 |
-
from models.pipeline_controlvideo import ControlVideoPipeline
|
22 |
-
from models.util import save_videos_grid, read_video, get_annotation
|
23 |
-
from models.unet import UNet3DConditionModel
|
24 |
-
from models.controlnet import ControlNetModel3D
|
25 |
-
from models.RIFE.IFNet_HDv3 import IFNet
|
26 |
-
|
27 |
-
hf_token = os.environ.get('HF_TOKEN')
|
28 |
-
device = "cuda"
|
29 |
|
30 |
model_ids = [
|
31 |
'runwayml/stable-diffusion-v1-5',
|
@@ -37,122 +13,22 @@ for model_id in model_ids:
|
|
37 |
model_name = model_id.split('/')[-1]
|
38 |
snapshot_download(model_id, local_dir=f'checkpoints/{model_name}')
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
inter_path = "checkpoints/flownet.pkl"
|
43 |
-
controlnet_dict = {
|
44 |
-
"pose": "checkpoints/sd-controlnet-openpose",
|
45 |
-
"depth": "checkpoints/sd-controlnet-depth",
|
46 |
-
"canny": "checkpoints/sd-controlnet-canny",
|
47 |
-
}
|
48 |
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
"
|
52 |
-
"depth": MidasDetector,
|
53 |
-
"canny": CannyDetector,
|
54 |
-
}
|
55 |
|
56 |
-
|
57 |
-
NEG_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
|
58 |
|
59 |
|
60 |
-
|
61 |
-
def get_args():
|
62 |
-
parser = argparse.ArgumentParser()
|
63 |
-
parser.add_argument("--prompt", type=str, required=True, help="Text description of target video")
|
64 |
-
parser.add_argument("--video_path", type=str, required=True, help="Path to a source video")
|
65 |
-
parser.add_argument("--output_path", type=str, default="./outputs", help="Directory of output")
|
66 |
-
parser.add_argument("--condition", type=str, default="depth", help="Condition of structure sequence")
|
67 |
-
parser.add_argument("--video_length", type=int, default=15, help="Length of synthesized video")
|
68 |
-
parser.add_argument("--height", type=int, default=512, help="Height of synthesized video, and should be a multiple of 32")
|
69 |
-
parser.add_argument("--width", type=int, default=512, help="Width of synthesized video, and should be a multiple of 32")
|
70 |
-
parser.add_argument("--smoother_steps", nargs='+', default=[19, 20], type=int, help="Timesteps at which using interleaved-frame smoother")
|
71 |
-
parser.add_argument("--is_long_video", action='store_true', help="Whether to use hierarchical sampler to produce long video")
|
72 |
-
parser.add_argument("--seed", type=int, default=42, help="Random seed of generator")
|
73 |
|
74 |
-
args = parser.parse_args()
|
75 |
-
return args
|
76 |
-
|
77 |
-
def infer(prompt, video_path, condition, video_length, is_long_video):
|
78 |
-
#args = get_args()
|
79 |
-
#os.makedirs(args.output_path, exist_ok=True)
|
80 |
-
|
81 |
-
# Height and width should be a multiple of 32
|
82 |
-
output_path = ""
|
83 |
-
height = 512
|
84 |
-
width = 512
|
85 |
-
height = (height // 32) * 32
|
86 |
-
width = (width // 32) * 32
|
87 |
-
smoother_steps = [19, 20]
|
88 |
-
is_long_video = False
|
89 |
-
seed = 42
|
90 |
-
|
91 |
-
if condition == "pose":
|
92 |
-
pretrained_model_or_path = "lllyasviel/ControlNet"
|
93 |
-
body_model_path = hf_hub_download(pretrained_model_or_path, "annotator/ckpts/body_pose_model.pth", cache_dir="checkpoints")
|
94 |
-
body_estimation = Body(body_model_path)
|
95 |
-
annotator = controlnet_parser_dict[condition](body_estimation)
|
96 |
-
else:
|
97 |
-
annotator = controlnet_parser_dict[condition]()
|
98 |
-
|
99 |
-
tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
|
100 |
-
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
|
101 |
-
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=torch.float16)
|
102 |
-
unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=torch.float16)
|
103 |
-
controlnet = ControlNetModel3D.from_pretrained_2d(controlnet_dict[condition]).to(dtype=torch.float16)
|
104 |
-
interpolater = IFNet(ckpt_path=inter_path).to(dtype=torch.float16)
|
105 |
-
scheduler=DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
|
106 |
-
|
107 |
-
pipe = ControlVideoPipeline(
|
108 |
-
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
|
109 |
-
controlnet=controlnet, interpolater=interpolater, scheduler=scheduler,
|
110 |
-
)
|
111 |
-
pipe.enable_vae_slicing()
|
112 |
-
pipe.enable_xformers_memory_efficient_attention()
|
113 |
-
pipe.to(device)
|
114 |
-
|
115 |
-
generator = torch.Generator(device="cuda")
|
116 |
-
generator.manual_seed(seed)
|
117 |
-
|
118 |
-
# Step 1. Read a video
|
119 |
-
video = read_video(video_path=video_path, video_length=video_length, width=width, height=height)
|
120 |
-
|
121 |
-
# Save source video
|
122 |
-
original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
|
123 |
-
save_videos_grid(original_pixels, os.path.join(output_path, "source_video.mp4"), rescale=True)
|
124 |
-
|
125 |
-
|
126 |
-
# Step 2. Parse a video to conditional frames
|
127 |
-
pil_annotation = get_annotation(video, annotator)
|
128 |
-
if condition == "depth" and controlnet_aux.__version__ == '0.0.1':
|
129 |
-
pil_annotation = [pil_annot[0] for pil_annot in pil_annotation]
|
130 |
-
|
131 |
-
# Save condition video
|
132 |
-
video_cond = [np.array(p).astype(np.uint8) for p in pil_annotation]
|
133 |
-
imageio.mimsave(os.path.join(output_path, f"{condition}_condition.mp4"), video_cond, fps=8)
|
134 |
-
|
135 |
-
# Reduce memory (optional)
|
136 |
-
del annotator; torch.cuda.empty_cache()
|
137 |
-
|
138 |
-
# Step 3. inference
|
139 |
-
|
140 |
-
if is_long_video:
|
141 |
-
window_size = int(np.sqrt(video_length))
|
142 |
-
sample = pipe.generate_long_video(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
|
143 |
-
num_inference_steps=50, smooth_steps=args.smoother_steps, window_size=window_size,
|
144 |
-
generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
|
145 |
-
width=width, height=height
|
146 |
-
).videos
|
147 |
-
else:
|
148 |
-
sample = pipe(prompt + POS_PROMPT, video_length=video_length, frames=pil_annotation,
|
149 |
-
num_inference_steps=50, smooth_steps=args.smoother_steps,
|
150 |
-
generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
|
151 |
-
width=width, height=height
|
152 |
-
).videos
|
153 |
-
save_videos_grid(sample, f"{output_path}/{prompt}.mp4")
|
154 |
|
155 |
-
return f"{output_path}/{prompt}.mp4"
|
156 |
|
157 |
with gr.Blocks() as demo:
|
158 |
with gr.Column():
|
@@ -160,16 +36,16 @@ with gr.Blocks() as demo:
|
|
160 |
video_path = gr.Video(source="upload", type="filepath")
|
161 |
condition = gr.Textbox(label="Condition", value="depth")
|
162 |
video_length = gr.Slider(label="video length", minimum=1, maximum=15, step=1, value=2)
|
163 |
-
seed = gr.Number(label="seed", value=42)
|
164 |
submit_btn = gr.Button("Submit")
|
165 |
-
video_res = gr.Video(label="result")
|
|
|
166 |
|
167 |
-
submit_btn.click(fn=
|
168 |
inputs=[prompt,
|
169 |
video_path,
|
170 |
condition,
|
171 |
-
video_length
|
172 |
-
seed,
|
173 |
],
|
174 |
outputs=[video_res])
|
175 |
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
model_ids = [
|
7 |
'runwayml/stable-diffusion-v1-5',
|
|
|
13 |
model_name = model_id.split('/')[-1]
|
14 |
snapshot_download(model_id, local_dir=f'checkpoints/{model_name}')
|
15 |
|
16 |
+
import subprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
def run_inference(prompt, video_path, condition, video_length):
|
19 |
+
command = "python inference.py --prompt prompt --condition condition --video_path video_path --output_path 'outputs/' --video_length video_length --smoother_steps 19 20"
|
20 |
+
output = subprocess.check_output(command, shell=True, text=True)
|
21 |
+
output = output.strip() # Remove any leading/trailing whitespace
|
22 |
|
23 |
+
# Process the output as needed
|
24 |
+
print("Command output:", output)
|
|
|
|
|
|
|
25 |
|
26 |
+
return "done"
|
|
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
#return f"{output_path}/{prompt}.mp4"
|
32 |
|
33 |
with gr.Blocks() as demo:
|
34 |
with gr.Column():
|
|
|
36 |
video_path = gr.Video(source="upload", type="filepath")
|
37 |
condition = gr.Textbox(label="Condition", value="depth")
|
38 |
video_length = gr.Slider(label="video length", minimum=1, maximum=15, step=1, value=2)
|
39 |
+
#seed = gr.Number(label="seed", value=42)
|
40 |
submit_btn = gr.Button("Submit")
|
41 |
+
#video_res = gr.Video(label="result")
|
42 |
+
video_res = gr.Textbox(label="result")
|
43 |
|
44 |
+
submit_btn.click(fn=run_inference,
|
45 |
inputs=[prompt,
|
46 |
video_path,
|
47 |
condition,
|
48 |
+
video_length
|
|
|
49 |
],
|
50 |
outputs=[video_res])
|
51 |
|