# genrl/demo/t2v.py
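"""Text-to-video demo for GenRL.

Loads a pretrained GenRL agent and the InternVideo2 video-language encoder
(downloading both from the Hugging Face Hub if needed), embeds a text prompt,
imagines a latent trajectory with the world model's connector, decodes it into
frames, and writes the result to an .mp4 file.
"""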
import os
import sys
import time
from pathlib import Path

# Make the current working directory importable so that `tools.genrl_utils` resolves
# (the demo expects to be launched from the repository root).
sys.path.append(str(Path(os.path.abspath(''))))

import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download

from tools.genrl_utils import ViCLIPGlobalInstance


def save_videos(batch_tensors, savedir, filenames, fps=10):
    """Save a batch of video tensors of shape (b, samples, c, t, h, w) as .mp4 files."""
    n_samples = batch_tensors.shape[1]
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), 0., 1.)
        video = video.permute(1, 0, 2, 3, 4)  # t, n, c, h, w
        # Arrange the samples of each frame side by side: [3, h, n*w]
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim: [t, 3, h, n*w]
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # t, h, n*w, 3 for write_video
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})


class Text2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1) -> None:
        model_folder = str(Path(os.path.abspath('')) / 'models')
        model_filename = 'genrl_stickman_500k_2.pt'
        # Fetch the GenRL agent and the InternVideo2 checkpoint if they are missing
        if not os.path.isfile(os.path.join(model_folder, model_filename)):
            self.download_model(model_folder, model_filename)
        if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
            self.download_internvideo2(model_folder)
        self.agent = torch.load(os.path.join(model_folder, model_filename))

        # Get the ViCLIP wrapper around InternVideo2
        model_name = 'internvideo2'
        viclip_global_instance = ViCLIPGlobalInstance(model_name)
        if not viclip_global_instance._instantiated:
            print("Instantiating InternVideo2")
            viclip_global_instance.instantiate()
        self.clip = viclip_global_instance.viclip
        self.tokenizer = viclip_global_instance.viclip_tokenizer

        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)

    def get_prompt(self, prompt, duration):
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()

        # Build a filesystem-safe name for the output file
        prompt_str = prompt.replace("/", "_slash_").replace(" ", "_")
        labels_list = [prompt_str]

        with torch.no_grad():
            wm = world_model = self.agent.wm
            connector = self.agent.wm.connector
            decoder = world_model.heads['decoder']
            n_frames = connector.n_frames

            # Embed the text prompt (used in place of a video embedding)
            text_feat = []
            for text in labels_list:
                with torch.no_grad():
                    text_feat.append(self.clip.get_txt_feat(text))
            text_feat = torch.stack(text_feat, dim=0).to(self.clip.device)
            video_embed = text_feat

            B = video_embed.shape[0]
            T = 1
            # Repeat the embedding for the requested number of imagination steps
            video_embed = video_embed.repeat(1, duration, 1)

            with torch.no_grad():
                # Imagine a latent trajectory conditioned on the embedding
                prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=True)
                # Decode the latents into frames
                prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5

        save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
        print(f"Saved as {prompt_str}.mp4 in {self.result_dir}. Time used: {(time.time() - start):.2f} seconds")
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self, model_folder, model_filename):
        REPO_ID = 'mazpie/genrl_models'
        filename_list = [model_filename]
        if not os.path.exists(model_folder):
            os.makedirs(model_folder)
        for filename in filename_list:
            local_file = os.path.join(model_folder, filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)

    def download_internvideo2(self, model_folder):
        REPO_ID = 'OpenGVLab/InternVideo2-Stage2_1B-224p-f4'
        filename_list = ['InternVideo2-stage2_1b-224p-f4.pt']
        if not os.path.exists(model_folder):
            os.makedirs(model_folder)
        for filename in filename_list:
            local_file = os.path.join(model_folder, filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False)


if __name__ == '__main__':
    t2v = Text2Video()
    video_path = t2v.get_prompt('a black swan swims on the pond', 8)
    print('done', video_path)
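
# Note: this demo appears to assume it is launched from the repository root
# (so that `tools.genrl_utils` is importable via the sys.path entry above),
# that a CUDA-capable GPU is available for the agent and the InternVideo2
# encoder, and that torchvision's video backend provides an H.264 encoder for
# `write_video(..., video_codec='h264')`.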