import math
import os
import time

import torch
import torchvision.transforms._transforms_video as transforms_video
from decord import VideoReader, cpu
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
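
# This module backs a demo for VideoCrafter's depth-adapter video-to-video
# pipeline: it samples frames from an input clip, resizes and center-crops them
# to a 64-aligned resolution, runs adapter-guided DDIM sampling conditioned on
# the text prompt, and saves the original, depth-condition, and generated clips
# as MP4 files.
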
def load_video(filepath, frame_stride, video_size=(256, 256), video_frames=16):
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)

    # Fall back to an adaptive stride when the requested stride would run past
    # the end of the clip (frame_stride == 0 means "choose automatically").
    if frame_stride != 0:
        if frame_stride * (video_frames - 1) >= max_frames:
            info_str += ("Warning: the requested frame stride makes the video too short; "
                         "falling back to an adaptive stride.\n")
            frame_stride = 0
    if frame_stride == 0:
        frame_stride = max_frames / video_frames
    if frame_stride > 100:
        frame_stride = 100
        info_str += (f"Warning: the input video is longer than {100 * video_frames} frames; "
                     f"only the first {100 * video_frames} frames will be used.\n")
    info_str += f"Frame stride is set to {frame_stride}"

    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)

    # [t,h,w,c] -> [c,t,h,w], scaled from [0, 255] to [-1, 1]
    frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
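
# Illustrative example of the adaptive stride (hypothetical numbers): a
# 120-frame clip with video_frames=16 and frame_stride=0 yields a stride of
# 120 / 16 = 7.5 and sample indices [0, 7, 15, 22, ..., 112]; the returned
# tensor has shape [3, 16, H, W] with values in [-1, 1].
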
class VideoControl:
    def __init__(self, result_dir='./tmp/') -> None:
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        # Prefer a copy of the base checkpoint in shared memory if one exists.
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'

        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint not found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
        torch.cuda.empty_cache()
        # Clamp user-facing parameters to safe ranges.
        if resolution > 512:
            resolution = 512
        if resolution < 64:
            resolution = 64
        if video_frames > 64:
            video_frames = 64
        resolution = int(resolution // 64) * 64
        if vc_steps > 60:
            vc_steps = 60
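        # e.g. a requested resolution of 500 passes the 64..512 clamp unchanged
        # and is then floored to the nearest multiple of 64: (500 // 64) * 64 = 448.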
        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            os.remove(input_video)
            return 'Please provide a valid input video.', None, None, None

        # Resize so the longer side matches `resolution`, preserving aspect ratio.
        if h > w:
            scale = h / resolution
        else:
            scale = w / resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:
            os.remove(input_video)
            return 'Failed to load the input video.', None, None, None
        # Floor the shorter side to a multiple of 64 (the longer side already is).
        if h > w:
            w = int(w // 64) * 64
        else:
            h = int(h // 64) * 64
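        # Worked example (hypothetical input): a 1920x1080 clip at resolution=256
        # gives scale = 1920 / 256 = 7.5, so (h, w) = (ceil(1080/7.5), ceil(1920/7.5))
        # = (144, 256); the shorter side is then floored to 128, and the center
        # crop below produces 128x256 frames.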
        spatial_transform = transforms_video.CenterCropVideo((h, w))
        video = spatial_transform(video)
        print('video shape', video.shape)

        rh, rw = h // 8, w // 8
        bs = 1
        channels = self.model.channels
        frames = video_frames
        noise_shape = [bs, channels, frames, rh, rw]
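        # Sampling happens in the autoencoder's latent space; the h // 8 and
        # w // 8 above reflect the 8x spatial downsampling of the latent frames.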
        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except Exception:
            # Most likely a CUDA out-of-memory error at large resolutions/frame counts.
            torch.cuda.empty_cache()
            info_str = "Out of memory; please try a smaller resolution or fewer frames."
            return info_str, None, None, None
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)

        # Derive a filesystem-safe filename from the prompt.
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        if len(filename) > 200:
            filename = filename[:200]
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")

        # Delete the uploaded video (but keep the bundled example clip).
        _, input_filename = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('deleted input video')
        return info_str, origin_path, depth_path, video_path
    def download_model(self):
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = [
            'models/base_t2v/model_rm_wtm.ckpt',
            'models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth',
            'models/adapter_t2v_depth/dpt_hybrid-midas.pt',
        ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
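        # With local_dir='./', hf_hub_download mirrors the repo's file layout on
        # disk, so each entry in filename_list lands at the same relative path
        # that the existence checks above (and __init__) expect.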
if __name__ == "__main__":
    vc = VideoControl('./result')
    # get_video returns (info_str, origin_path, depth_path, video_path).
    info_str, origin_path, depth_path, video_path = vc.get_video('input/flamingo.mp4', "An ostrich walking in the desert, photorealistic, 4k")