# VideoCrafter / videocontrol_test.py
import argparse, os, sys, glob
import datetime, time
from omegaconf import OmegaConf
import math
import torch
from decord import VideoReader, cpu
import torchvision
from pytorch_lightning import seed_everything
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
import torchvision.transforms._transforms_video as transforms_video
from huggingface_hub import hf_hub_download
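
# Test script for VideoCrafter's depth-adapter video control: it downloads the base
# text-to-video checkpoint and the depth adapter, extracts frames from an input video,
# and runs adapter-guided DDIM sampling to produce a prompt-driven video, saving the
# resized input, the depth condition, and the sampled result as mp4 files.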
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)
    # frame_stride == 0 means "choose an adaptive stride automatically"
    if frame_stride != 0:
        if frame_stride * (video_frames - 1) >= max_frames:
            info_str += "Warning: the requested frame stride exceeds the video length; falling back to an adaptive stride.\n"
            frame_stride = 0
    if frame_stride == 0:
        frame_stride = max_frames / video_frames
        # if temp_stride < 1:
        #     info_str = "Warning: the input video is shorter than 16 frames; it will be padded to 16 frames.\n"
    if frame_stride > 100:
        frame_stride = 100
        info_str += "Warning: the input video is longer than 1600 frames; only the first 1600 frames will be processed.\n"
    info_str += f"Frame stride is set to {frame_stride}"
    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)
    ## [t,h,w,c] -> [c,t,h,w]
    frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
class VideoControl:
    def __init__(self, result_dir='./tmp/') -> None:
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        # prefer a copy of the base checkpoint in shared memory if one exists
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint not found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model
    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
        torch.cuda.empty_cache()
        # clamp user inputs to the supported ranges
        resolution = min(max(resolution, 64), 512)
        video_frames = min(video_frames, 64)
        resolution = int(resolution // 64) * 64
        vc_steps = min(vc_steps, 60)
        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            os.remove(input_video)
            return 'Please provide a valid input video.', None, None, None
        # rescale so that the longer side matches the target resolution
        if h > w:
            scale = h / resolution
        else:
            scale = w / resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:
            os.remove(input_video)
            return 'Error: failed to load the input video.', None, None, None
        # round the shorter side down to a multiple of 64, then center-crop
        if h > w:
            w = int(w // 64) * 64
        else:
            h = int(h // 64) * 64
        spatial_transform = transforms_video.CenterCropVideo((h, w))
        video = spatial_transform(video)
        print('video shape', video.shape)
        # latent spatial size is 1/8 of the pixel resolution
        rh, rw = h // 8, w // 8
        bs = 1
        channels = self.model.channels
        # frames = self.model.temporal_length
        frames = video_frames
        noise_shape = [bs, channels, frames, rh, rw]
        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except Exception:
            torch.cuda.empty_cache()
            info_str = "Out of memory: please try a smaller resolution or fewer frames."
            return info_str, None, None, None
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # build a filesystem-safe filename from the prompt
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        if len(filename) > 200:
            filename = filename[:200]
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
        # delete the uploaded video (keep the bundled example)
        (path, input_filename) = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('deleted input video')
            # print(input_video)
        return info_str, origin_path, depth_path, video_path
    def download_model(self):
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = [
            'models/base_t2v/model_rm_wtm.ckpt',
            'models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth',
            'models/adapter_t2v_depth/dpt_hybrid-midas.pt',
        ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
if __name__ == "__main__":
    vc = VideoControl('./result')
    # get_video returns (info message, resized input path, depth path, sample path)
    info_str, origin_path, depth_path, video_path = vc.get_video('input/flamingo.mp4', "An ostrich walking in the desert, photorealistic, 4k")