Haoxin Chen committed on
Commit
0757c55
·
1 Parent(s): 1e5bda6

add videocontrol code

Browse files
Files changed (5) hide show
  1. .gitignore +4 -1
  2. app.py +6 -4
  3. demo_test.py +1 -1
  4. input/flamingo.mp4 +0 -0
  5. videocontrol_test.py +131 -0
.gitignore CHANGED
@@ -1,7 +1,10 @@
1
  .DS_Store
2
  *pyc
 
3
  __pycache__
4
  *.egg-info
5
 
6
  results
7
- *.ckpt
 
 
 
1
  .DS_Store
2
  *pyc
3
+ .vscode
4
  __pycache__
5
  *.egg-info
6
 
7
  results
8
+ *.ckpt
9
+ *.pt
10
+ *.pth
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
  import sys
3
  import gradio as gr
4
- from demo_test import Text2Video, VideoContorl
 
5
  sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
6
 
7
  t2v_examples = [
@@ -14,12 +15,12 @@ t2v_examples = [
14
  ]
15
 
16
  control_examples = [
17
- ['01.mp4', 'a dog', 0, 50, 15, 1]
18
  ]
19
 
20
  def videocrafter_demo(result_dir='./tmp/'):
21
  text2video = Text2Video(result_dir)
22
- videocontrol = VideoContorl()
23
  with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
24
  gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
25
  <a style='font-size:18px;color: #efefef' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
@@ -89,4 +90,5 @@ def videocrafter_demo(result_dir='./tmp/'):
89
  if __name__ == "__main__":
90
  result_dir = os.path.join('./', 'results')
91
  videocrafter_iface = videocrafter_demo(result_dir)
92
- videocrafter_iface.launch(server_name='0.0.0.0', server_port=80)
 
 
1
  import os
2
  import sys
3
  import gradio as gr
4
+ from videocrafter_test import Text2Video
5
+ from videocontrol_test import VideoControl
6
  sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
7
 
8
  t2v_examples = [
 
15
  ]
16
 
17
  control_examples = [
18
+ ['input/flamingo.mp4', 'An ostrich walking in the desert, photorealistic, 4k', 0, 50, 15, 1]
19
  ]
20
 
21
  def videocrafter_demo(result_dir='./tmp/'):
22
  text2video = Text2Video(result_dir)
23
+ videocontrol = VideoControl(result_dir)
24
  with gr.Blocks(analytics_enabled=False) as videocrafter_iface:
25
  gr.Markdown("<div align='center'> <h2> VideoCrafter: A Toolkit for Text-to-Video Generation and Editing </span> </h2> \
26
  <a style='font-size:18px;color: #efefef' href='https://github.com/VideoCrafter/VideoCrafter'> Github </div>")
 
90
  if __name__ == "__main__":
91
  result_dir = os.path.join('./', 'results')
92
  videocrafter_iface = videocrafter_demo(result_dir)
93
+ videocrafter_iface.launch()
94
+ # videocrafter_iface.launch(server_name='0.0.0.0', server_port=80)
demo_test.py CHANGED
@@ -6,7 +6,7 @@ class Text2Video():
6
 
7
  return '01.mp4'
8
 
9
- class VideoContorl:
10
  def __init__(self) -> None:
11
  pass
12
 
 
6
 
7
  return '01.mp4'
8
 
9
+ class VideoControl:
10
  def __init__(self) -> None:
11
  pass
12
 
input/flamingo.mp4 ADDED
Binary file (897 kB). View file
 
videocontrol_test.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, glob
2
+ import datetime, time
3
+ from omegaconf import OmegaConf
4
+ import math
5
+
6
+ import torch
7
+ from decord import VideoReader, cpu
8
+ import torchvision
9
+ from pytorch_lightning import seed_everything
10
+
11
+ from lvdm.samplers.ddim import DDIMSampler
12
+ from lvdm.utils.common_utils import instantiate_from_config
13
+ from lvdm.utils.saving_utils import tensor_to_mp4
14
+ from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
15
+
16
+ import torchvision.transforms._transforms_video as transforms_video
17
+ from huggingface_hub import hf_hub_download
18
+
19
+
20
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
    """Sample a fixed number of frames from a video file and return them as a tensor.

    Args:
        filepath: Path of the video, handed to decord's ``VideoReader``.
        frame_stride: Distance between sampled frames. ``0`` asks for an
            adaptive stride of ``len(video) / video_frames``.
        video_size: ``(height, width)`` the reader is asked to decode at.
        video_frames: How many frames to sample.

    Returns:
        ``(frame_tensor, info_str)``. ``frame_tensor`` is the sampled batch
        permuted from [t,h,w,c] to [c,t,h,w], cast to float and mapped with
        ``(x / 255 - 0.5) * 2`` (i.e. to [-1, 1], assuming 8-bit frames --
        TODO confirm). ``info_str`` collects any warnings plus the stride
        that was finally used.
    """
    reader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    total_frames = len(reader)
    notes = []

    # A user-chosen stride whose last sampled index would fall past the end
    # of the clip is dropped in favour of the adaptive stride below.
    if frame_stride != 0 and frame_stride * (video_frames-1) >= total_frames:
        notes.append("Warning: The user-set frame rate makes the current video length not enough, we will set it to an adaptive frame rate.\n")
        frame_stride = 0

    if frame_stride == 0:
        # Adaptive stride; may be fractional (and below 1 for short clips,
        # in which case some frame indices repeat after truncation).
        frame_stride = total_frames / video_frames

    # The stride is capped at 8, so with 16 frames at most the first
    # 128 frames of the clip are covered.
    if frame_stride > 8:
        frame_stride = 8
        notes.append("Warning: The current input video length is longer than 128 frames, we will process only the first 128 frames.\n")

    notes.append(f"Frame Stride is set to {frame_stride}")
    info_str = ''.join(notes)

    picked = [int(frame_stride*k) for k in range(video_frames)]
    batch = reader.get_batch(picked)

    ## [t,h,w,c] -> [c,t,h,w]
    frame_tensor = torch.tensor(batch.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
45
+
46
class VideoControl:
    """Wrapper that runs adapter-guided video synthesis on an input clip.

    Construction fetches any missing checkpoint files, builds the model from
    ``models/adapter_t2v_depth/model_config.yaml``, moves it to CUDA and hands
    the base and adapter checkpoint paths to ``load_model_checkpoint``.
    """

    def __init__(self, result_dir='./tmp/') -> None:
        """Download weights if needed and prepare the model for inference.

        Args:
            result_dir: Directory generated videos are written to.
        """
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter.pth"

        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
        self.resolution=256
        self.spatial_transform = transforms_video.CenterCropVideo(self.resolution)

    @staticmethod
    def _discard_input(path):
        """Best-effort removal of the input video file.

        Does nothing for an empty/None path and ignores OS-level failures, so
        cleanup can never mask the message or result being returned.
        """
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            # Cleanup is deliberately best-effort (file already gone, no permission, ...).
            pass

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0):
        """Generate a video guided by ``input_video`` and ``input_prompt``.

        The input file is deleted before returning, on success and on failure.

        Args:
            input_video: Path of the source video; may be None when nothing was supplied.
            input_prompt: Text prompt; also used to derive the output file name.
            frame_stride: Stride forwarded to ``load_video`` (0 = adaptive).
            vc_steps: Forwarded as ``ddim_steps``.
            vc_cfg_scale: Forwarded as ``unconditional_guidance_scale``.
            vc_eta: Forwarded as ``ddim_eta``.

        Returns:
            ``(info_str, video_path)`` on success, or ``(error_message, None)``
            when the video cannot be opened or loaded.
        """
        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, _ = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            # input_video may be None here, so removal must tolerate that.
            self._discard_input(input_video)
            return 'please input video', None

        # Rescale so that the shorter side equals self.resolution; the centre
        # crop applied below then yields a square clip.
        scale = min(h, w) / self.resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
        except Exception:
            self._discard_input(input_video)
            return 'load video error', None
        video = self.spatial_transform(video)
        print('video shape', video.shape)

        # Spatial size of the sampled noise -- presumably the latent size for
        # 256x256 input; TODO confirm against the model config.
        h, w = 32, 32
        bs = 1
        channels = self.model.channels
        frames = self.model.temporal_length
        noise_shape = [bs, channels, frames, h, w]

        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        with torch.no_grad():
            batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # Make the prompt usable as a file name: no path separators, no spaces.
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        # tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=os.path.join(self.savedir, f'{filename}_depth.mp4'), fps=10)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)

        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
        # delete video
        self._discard_input(input_video)
        return info_str, video_path

    def download_model(self):
        """Fetch every checkpoint file that is not already present locally."""
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        # Each entry needs its own trailing comma: adjacent string literals
        # are concatenated by Python into a single (non-existent) file name.
        filename_list = [
            'models/base_t2v/model.ckpt',
            "models/adapter_t2v_depth/adapter.pth",
            "models/adapter_t2v_depth/dpt_hybrid-midas.pt",
        ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
123
+
124
+
125
+
126
+
127
+
128
+
129
if __name__ == "__main__":
    # Manual smoke run on the bundled sample clip.
    # NOTE: get_video removes its input file once it has finished.
    sample_clip = 'input/flamingo.mp4'
    sample_prompt = "An ostrich walking in the desert, photorealistic, 4k"
    vc = VideoControl('./result')
    info_str, video_path = vc.get_video(sample_clip, sample_prompt)