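"""Video-to-video generation demo built on VideoCrafter (t2v-version-1-1).

Samples frames from an input clip, then synthesizes a new video from a text
prompt under depth guidance (depth adapter + DDIM sampling).
"""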
import argparse, os, sys, glob
import datetime, time
from omegaconf import OmegaConf
import math
import torch
from decord import VideoReader, cpu
import torchvision
from pytorch_lightning import seed_everything
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
import torchvision.transforms._transforms_video as transforms_video
from huggingface_hub import hf_hub_download
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)
    # fall back to an adaptive stride when the requested one does not fit the clip
    if frame_stride != 0:
        if frame_stride * (video_frames - 1) >= max_frames:
            info_str += "Warning: the requested frame stride exceeds the video length; falling back to an adaptive stride.\n"
            frame_stride = 0
    if frame_stride == 0:
        frame_stride = max_frames / video_frames
    if frame_stride > 100:
        frame_stride = 100
        info_str += "Warning: the input video is longer than 1600 frames; only the first 1600 frames will be processed.\n"
    info_str += f"Frame stride is set to {frame_stride}"
    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)
    ## [t,h,w,c] -> [c,t,h,w], normalized to [-1, 1]
    frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
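# Illustration (values assumed, not from the source): for a 120-frame clip with
# frame_stride=0 and video_frames=16, the adaptive stride is 120/16 = 7.5 and the
# sampled indices are [0, 7, 15, 22, 30, ...]; the returned tensor has shape
# [3, 16, H, W] with values in [-1, 1].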
class VideoControl:
    def __init__(self, result_dir='./tmp/') -> None:
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        # prefer a checkpoint cached in shared memory, if one exists
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint not found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
        torch.cuda.empty_cache()
        # clamp user inputs to supported ranges
        if resolution > 512:
            resolution = 512
        if resolution < 64:
            resolution = 64
        if video_frames > 64:
            video_frames = 64
        resolution = int(resolution // 64) * 64
        if vc_steps > 60:
            vc_steps = 60
        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            os.remove(input_video)
            return 'Please provide a valid input video.', None, None, None
        # resize so the longer side matches the target resolution, keeping aspect ratio
        if h > w:
            scale = h / resolution
        else:
            scale = w / resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:
            os.remove(input_video)
            return 'Failed to load the input video.', None, None, None
        # center-crop the shorter side down to a multiple of 64
        if h > w:
            w = int(w // 64) * 64
        else:
            h = int(h // 64) * 64
        spatial_transform = transforms_video.CenterCropVideo((h, w))
        video = spatial_transform(video)
        print('video shape', video.shape)
        # latent spatial size is 1/8 of the pixel size
        rh, rw = h // 8, w // 8
        bs = 1
        channels = self.model.channels
        # frames = self.model.temporal_length
        frames = video_frames
        noise_shape = [bs, channels, frames, rh, rw]
        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except Exception:
            torch.cuda.empty_cache()
            info_str = "Out of GPU memory; please try a smaller resolution or fewer frames."
            return info_str, None, None, None
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # derive a filesystem-safe filename from the prompt
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        if len(filename) > 200:
            filename = filename[:200]
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
        # delete the uploaded video, but keep the bundled example
        (path, input_filename) = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('deleted input video')
        return info_str, origin_path, depth_path, video_path
    def download_model(self):
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = ['models/base_t2v/model_rm_wtm.ckpt',
                         "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth",
                         "models/adapter_t2v_depth/dpt_hybrid-midas.pt"]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
if __name__ == "__main__":
    vc = VideoControl('./result')
    info_str, origin_path, depth_path, video_path = vc.get_video('input/flamingo.mp4', "An ostrich walking in the desert, photorealistic, 4k")
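# Outputs land in the result dir as '<prompt>_sample.mp4' (generated video),
# '<prompt>_depth.mp4' (depth condition), and '<prompt>.mp4' (resized input),
# all encoded at 8 fps.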