Spaces:
Runtime error
Runtime error
Haoxin Chen
committed on
Commit
·
0757c55
1
Parent(s):
1e5bda6
add videocontrol code
Browse files
- .gitignore +4 -1
- app.py +6 -4
- demo_test.py +1 -1
- input/flamingo.mp4 +0 -0
- videocontrol_test.py +131 -0
.gitignore
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
.DS_Store
|
2 |
*pyc
|
|
|
3 |
__pycache__
|
4 |
*.egg-info
|
5 |
|
6 |
results
|
7 |
-
*.ckpt
|
|
|
|
|
|
1 |
.DS_Store
|
2 |
*pyc
|
3 |
+
.vscode
|
4 |
__pycache__
|
5 |
*.egg-info
|
6 |
|
7 |
results
|
8 |
+
*.ckpt
|
9 |
+
*.pt
|
10 |
+
*.pth
|
app.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
import gradio as gr
|
4 |
-
from
|
|
|
5 |
sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
|
6 |
|
7 |
t2v_examples = [
|
@@ -14,12 +15,12 @@ t2v_examples = [
|
|
14 |
]
|
15 |
|
16 |
control_examples = [
|
17 |
-
['
|
18 |
]
|
19 |
|
20 |
def videocrafter_demo(result_dir='./tmp/'):
|
21 |
text2video = Text2Video(result_dir)
|
22 |
-
videocontrol =
|
23 |
with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
|
24 |
gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
|
25 |
<a style='font-size:18px;color: #efefef' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
|
@@ -89,4 +90,5 @@ def videocrafter_demo(result_dir='./tmp/'):
|
|
89 |
if __name__ == "__main__":
|
90 |
result_dir = os.path.join('./', 'results')
|
91 |
videocrafter_iface = videocrafter_demo(result_dir)
|
92 |
-
videocrafter_iface.launch(
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
import gradio as gr
|
4 |
+
from videocrafter_test import Text2Video
|
5 |
+
from videocontrol_test import VideoControl
|
6 |
sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
|
7 |
|
8 |
t2v_examples = [
|
|
|
15 |
]
|
16 |
|
17 |
control_examples = [
|
18 |
+
['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1]
|
19 |
]
|
20 |
|
21 |
def videocrafter_demo(result_dir='./tmp/'):
|
22 |
text2video = Text2Video(result_dir)
|
23 |
+
videocontrol = VideoControl(result_dir)
|
24 |
with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
|
25 |
gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
|
26 |
<a style='font-size:18px;color: #efefef' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
|
|
|
90 |
if __name__ == "__main__":
|
91 |
result_dir = os.path.join('./', 'results')
|
92 |
videocrafter_iface = videocrafter_demo(result_dir)
|
93 |
+
videocrafter_iface.launch()
|
94 |
+
# videocrafter_iface.launch(server_name='0.0.0.0', server_port=80)
|
demo_test.py
CHANGED
@@ -6,7 +6,7 @@ class Text2Video():
|
|
6 |
|
7 |
return '01.mp4'
|
8 |
|
9 |
-
class
|
10 |
def __init__(self) -> None:
|
11 |
pass
|
12 |
|
|
|
6 |
|
7 |
return '01.mp4'
|
8 |
|
9 |
+
class VideoControl:
|
10 |
def __init__(self) -> None:
|
11 |
pass
|
12 |
|
input/flamingo.mp4
ADDED
Binary file (897 kB). View file
|
|
videocontrol_test.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, glob
|
2 |
+
import datetime, time
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from decord import VideoReader, cpu
|
8 |
+
import torchvision
|
9 |
+
from pytorch_lightning import seed_everything
|
10 |
+
|
11 |
+
from lvdm.samplers.ddim import DDIMSampler
|
12 |
+
from lvdm.utils.common_utils import instantiate_from_config
|
13 |
+
from lvdm.utils.saving_utils import tensor_to_mp4
|
14 |
+
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
|
15 |
+
|
16 |
+
import torchvision.transforms._transforms_video as transforms_video
|
17 |
+
from huggingface_hub import hf_hub_download
|
18 |
+
|
19 |
+
|
20 |
+
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
    """Sample ``video_frames`` frames from a video file and return them as a tensor.

    Args:
        filepath: Path of the video, opened with decord's ``VideoReader``.
        frame_stride: Distance (in source frames) between sampled frames.
            ``0`` selects an adaptive stride of ``len(video) / video_frames``.
            A non-zero stride that would run past the end of the video is
            replaced by the adaptive stride, with a warning in the info text.
        video_size: ``(height, width)`` the reader is asked to decode at.
        video_frames: Number of frames to sample.

    Returns:
        ``(frame_tensor, info_str)``. ``frame_tensor`` is a float tensor laid
        out as ``[c, t, h, w]`` and rescaled with ``(x / 255 - 0.5) * 2``
        (i.e. to [-1, 1], assuming 8-bit pixel values -- TODO confirm).
        ``info_str`` collects any warnings plus the stride finally used.
    """
    reader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    total_frames = len(reader)
    messages = []

    stride = frame_stride
    # A user-chosen stride that does not fit in the clip falls back to "auto".
    if stride != 0 and stride * (video_frames-1) >= total_frames:
        messages.append("Warning: The user-set frame rate makes the current video length not enough, we will set it to an adaptive frame rate.\n")
        stride = 0

    # Auto mode: spread the samples evenly over the whole clip.
    if stride == 0:
        stride = total_frames / video_frames

    # The stride is capped at 8 whatever its origin.
    if stride > 8:
        stride = 8
        messages.append("Warning: The current input video length is longer than 128 frames, we will process only the first 128 frames.\n")

    messages.append(f"Frame Stride is set to {stride}")
    info_str = ''.join(messages)

    picked = [int(stride*k) for k in range(video_frames)]
    batch = reader.get_batch(picked)

    ## [t,h,w,c] -> [c,t,h,w]
    pixels = torch.tensor(batch.asnumpy())
    frame_tensor = pixels.permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
|
45 |
+
|
46 |
+
class VideoControl:
    """Adapter-guided video generation driven by an input video and a prompt.

    Construction downloads any missing weight files, builds the model from
    ``models/adapter_t2v_depth/model_config.yaml``, moves it to CUDA and loads
    the base and adapter checkpoints. ``get_video`` then turns an input video
    plus a text prompt into a generated ``.mp4`` inside ``result_dir``.

    NOTE(review): judging by the checkpoint paths the conditioning is depth
    based; that is not visible in this class itself -- confirm against
    ``adapter_guided_synthesis``.
    """

    # Hugging Face Hub repository the weight files are fetched from.
    REPO_ID = 'VideoCrafter/t2v-version-1-1'
    # Repo-relative weight files; each one is also the local path it is
    # expected at. The trailing commas matter: without them adjacent string
    # literals are silently concatenated into a single bogus path.
    MODEL_FILES = (
        'models/base_t2v/model.ckpt',
        'models/adapter_t2v_depth/adapter.pth',
        'models/adapter_t2v_depth/dpt_hybrid-midas.pt',
    )

    def __init__(self, result_dir='./tmp/') -> None:
        """Download missing weights and load the model onto the GPU.

        Args:
            result_dir: Directory generated videos are written to (created on
                first use by ``get_video``).

        Raises:
            AssertionError: If the base checkpoint is still missing after the
                download step.
        """
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter.pth"

        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
        # Side length the shorter video edge is resized to before the
        # centre crop.
        self.resolution = 256
        self.spatial_transform = transforms_video.CenterCropVideo(self.resolution)

    @staticmethod
    def _discard_input(path):
        """Best-effort removal of the input file.

        Cleanup must never hide the real outcome of ``get_video``, so a
        missing file, a permission problem or a ``None`` path is ignored.
        """
        try:
            os.remove(path)
        except (OSError, TypeError):
            pass

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0):
        """Generate a video guided by ``input_video`` and ``input_prompt``.

        The input file is deleted afterwards, on success and on the handled
        failure paths alike, so callers must pass a disposable copy.

        Args:
            input_video: Path of the guiding video.
            input_prompt: Text prompt; also used to build the output filename.
            frame_stride: Frame sampling stride forwarded to ``load_video``
                (``0`` means adaptive).
            vc_steps: Number of DDIM steps.
            vc_cfg_scale: Unconditional guidance scale.
            vc_eta: DDIM eta.

        Returns:
            ``(info_str, video_path)`` on success. When the video cannot be
            opened or loaded, ``(error_message, None)``.
        """
        ## load video
        print("input video", input_video)
        try:
            h, w, _ = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            self._discard_input(input_video)
            return 'please input video', None

        # Resize so that the shorter edge equals self.resolution.
        scale = min(h, w) / self.resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
        except Exception:
            self._discard_input(input_video)
            return 'load video error', None
        video = self.spatial_transform(video)
        print('video shape', video.shape)

        # Spatial size of the sampled noise; presumably the latent size that
        # corresponds to a 256x256 frame -- TODO confirm against the model.
        h, w = 32, 32
        bs = 1
        channels = self.model.channels
        frames = self.model.temporal_length
        noise_shape = [bs, channels, frames, h, w]

        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        with torch.no_grad():
            batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # The prompt doubles as the file stem; make it path-safe.
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)

        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
        # delete video
        self._discard_input(input_video)
        return info_str, video_path

    def download_model(self):
        """Fetch every file in ``MODEL_FILES`` that is not already on disk."""
        for filename in self.MODEL_FILES:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=self.REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
    # Smoke test: generate one clip from the bundled sample video.
    # get_video() deletes the file it is handed, so work on a throwaway copy
    # rather than on the sample that is tracked in the repository.
    import shutil
    import tempfile

    vc = VideoControl('./result')
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_video = shutil.copy('input/flamingo.mp4', scratch_dir)
        info_str, video_path = vc.get_video(scratch_video, "An ostrich walking in the desert, photorealistic, 4k")
|