# VideoCrafter / videocontrol_test.py
# Uploaded by imjunaidafzal; duplicated from VideoCrafter/VideoCrafter (commit 153e804, 6.8 kB).
import argparse, os, sys, glob
import datetime, time
from omegaconf import OmegaConf
import math
import torch
from decord import VideoReader, cpu
import torchvision
from pytorch_lightning import seed_everything
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import instantiate_from_config
from lvdm.utils.saving_utils import tensor_to_mp4
from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis
import torchvision.transforms._transforms_video as transforms_video
from huggingface_hub import hf_hub_download
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16, max_stride=100):
    """Sample ``video_frames`` strided frames from ``filepath`` as a normalized tensor.

    Args:
        filepath: path of the video file, opened with decord's ``VideoReader``.
        frame_stride: distance (in source frames) between sampled frames.
            ``0`` selects an adaptive stride that spreads the samples over the
            whole clip; a negative value is treated like ``0``. A user-set
            stride that would run past the end of the clip also falls back to
            the adaptive stride.
        video_size: ``(height, width)`` passed to the decoder for resizing.
        video_frames: number of frames to sample.
        max_stride: upper bound on the stride; longer clips are only sampled
            over their first ``max_stride * video_frames`` frames.

    Returns:
        ``(frame_tensor, info_str)``: ``frame_tensor`` is a float tensor laid
        out ``[c, t, h, w]`` with values mapped from ``[0, 255]`` to ``[-1, 1]``;
        ``info_str`` accumulates user-facing warnings and the stride used.
    """
    info_str = ''
    vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    max_frames = len(vidreader)
    # A negative stride would produce negative frame indices; treat it as "auto".
    if frame_stride < 0:
        frame_stride = 0
    # A user-set stride that does not fit inside the clip falls back to "auto".
    if frame_stride != 0 and frame_stride * (video_frames - 1) >= max_frames:
        info_str += "Warning: The user-set frame rate makes the current video length not enough, we will set it to an adaptive frame rate.\n"
        frame_stride = 0
    # Auto: spread the samples over the whole clip. For clips shorter than
    # ``video_frames`` the stride is below 1, so some frames are repeated.
    if frame_stride == 0:
        frame_stride = max_frames / video_frames
    if frame_stride > max_stride:
        frame_stride = max_stride
        # Derived from the parameters instead of a hard-coded 1600 (= 100 * 16).
        covered = max_stride * video_frames
        info_str += f"Warning: The current input video length is longer than {covered} frames, we will process only the first {covered} frames.\n"
    info_str += f"Frame Stride is set to {frame_stride}"
    frame_indices = [int(frame_stride * i) for i in range(video_frames)]
    frames = vidreader.get_batch(frame_indices)
    # [t, h, w, c] -> [c, t, h, w]
    frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
class VideoControl:
    """Depth-guided text-to-video generator built on the VideoCrafter adapter model.

    The model is constructed once in ``__init__``; each :meth:`get_video` call
    handles one request and writes its outputs under ``result_dir``.
    """

    def __init__(self, result_dir='./tmp/') -> None:
        """Download missing weights and build the depth-adapter model on CUDA.

        Args:
            result_dir: directory where generated videos are saved.

        Raises:
            FileNotFoundError: if the base or adapter checkpoint is missing.
        """
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        # Prefer a copy staged in /dev/shm when present (presumably for faster loading).
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'
        # Validate before the expensive model construction. An explicit raise is
        # used instead of ``assert``, which is stripped under ``python -O``.
        for path in (ckpt_path, adapter_ckpt):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Error: checkpoint Not Found! ({path})")
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
        """Generate a depth-conditioned video for ``input_prompt`` from ``input_video``.

        Args:
            input_video: path of the uploaded driving video (may be None).
            input_prompt: text prompt; also used to derive output filenames.
            frame_stride: stride passed to :func:`load_video` (``0`` = auto).
            vc_steps: DDIM steps, capped at 60.
            vc_cfg_scale: passed as ``unconditional_guidance_scale``.
            vc_eta: passed as ``ddim_eta``.
            video_frames: number of frames to sample/generate, capped at 64.
            resolution: target long side, clamped to [64, 512] and floored to a
                multiple of 64.

        Returns:
            ``(info_str, origin_path, depth_path, video_path)``. On failure
            ``info_str`` is an error message and the three paths are None.

        Side effects:
            Deletes ``input_video`` after a successful run or a failed read,
            except the bundled example ``flamingo.mp4``.
        """
        torch.cuda.empty_cache()
        # Clamp user-controlled knobs to the ranges this demo supports.
        resolution = int(min(max(resolution, 64), 512) // 64) * 64
        video_frames = int(min(video_frames, 64))  # UI sliders may deliver floats
        vc_steps = min(vc_steps, 60)
        ## load video
        print("input video", input_video)
        try:
            h, w, _ = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:
            self._remove_input_video(input_video)
            return 'please input video', None, None, None
        h, w = self._target_size(h, w, resolution)
        # The latent is h/8 x w/8, and this pipeline aligns both sides to
        # multiples of 64. The long side already equals ``resolution``.
        crop_h, crop_w = (h // 64) * 64, (w // 64) * 64
        if crop_h == 0 or crop_w == 0:
            self._remove_input_video(input_video)
            return 'video aspect ratio is too extreme for this resolution, please increase the resolution', None, None, None
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:
            self._remove_input_video(input_video)
            return 'load video error', None, None, None
        video = transforms_video.CenterCropVideo((crop_h, crop_w))(video)
        print('video shape', video.shape)
        noise_shape = [1, self.model.channels, video_frames, crop_h // 8, crop_w // 8]
        ## inference
        start = time.perf_counter()
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, input_prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except Exception as err:
            torch.cuda.empty_cache()
            # The input is kept so the user can retry with smaller settings.
            if 'out of memory' in str(err).lower():
                return "OOM, please enter a smaller resolution or smaller frame num", None, None, None
            return f"generation error: {err}", None, None, None
        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        stem = self._output_stem(input_prompt)
        video_path = os.path.join(self.savedir, f'{stem}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{stem}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{stem}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.perf_counter() - start):.2f} seconds")
        self._remove_input_video(input_video)
        return info_str, origin_path, depth_path, video_path

    @staticmethod
    def _target_size(h, w, resolution):
        """Return ``(h, w)`` scaled so the longer side equals ``resolution``.

        The shorter side is rounded up. Integer arithmetic is used because the
        float form ``ceil(h / (h / resolution))`` can round up to
        ``resolution + 1``.
        """
        if h > w:
            return resolution, -(-w * resolution // h)
        return -(-h * resolution // w), resolution

    @staticmethod
    def _output_stem(prompt, max_bytes=200):
        """Derive a filesystem-safe filename stem from ``prompt``.

        Path separators and NULs are neutralised. The result is truncated to
        ``max_bytes`` UTF-8 bytes, not characters, so non-ASCII prompts stay
        under the usual 255-byte filename limit. Returns ``"video"`` when
        nothing is left.
        """
        stem = prompt.replace("/", "_slash_").replace("\\", "_").replace("\0", "").replace(" ", "_")
        stem = stem.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")
        return stem or "video"

    @staticmethod
    def _remove_input_video(input_video):
        """Delete the uploaded input video, sparing the bundled example.

        Returns True if a file was removed. None/empty paths and missing files
        are ignored.
        """
        if not input_video or os.path.basename(input_video) == 'flamingo.mp4':
            return False
        try:
            os.remove(input_video)
        except FileNotFoundError:
            return False
        print('delete input video')
        return True

    def download_model(self):
        """Fetch missing weight files from the Hugging Face Hub into ``./``.

        Files already present at their relative path are not downloaded again.
        """
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = ['models/base_t2v/model_rm_wtm.ckpt',
                         "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth",
                         "models/adapter_t2v_depth/dpt_hybrid-midas.pt",
                         ]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
if __name__ == "__main__":
    # Smoke run on the bundled example clip (not deleted by get_video).
    vc = VideoControl('./result')
    # get_video returns four values; unpacking only two raised ValueError.
    info_str, origin_path, depth_path, video_path = vc.get_video('input/flamingo.mp4', "An ostrich walking in the desert, photorealistic, 4k")
    print(info_str)