Spaces:
Running
on
Zero
Running
on
Zero
from pathlib import Path | |
import os | |
import sys | |
sys.path.append(str(Path(os.path.abspath('')))) | |
import torch | |
import numpy as np | |
from tools.genrl_utils import ViCLIPGlobalInstance | |
import time | |
import torchvision | |
from huggingface_hub import hf_hub_download | |
def save_videos(batch_tensors, savedir, filenames, fps=10): | |
# b,samples,c,t,h,w | |
n_samples = batch_tensors.shape[1] | |
for idx, vid_tensor in enumerate(batch_tensors): | |
video = vid_tensor.detach().cpu() | |
video = torch.clamp(video.float(), 0., 1.) | |
video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w | |
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w] | |
grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] | |
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) | |
savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") | |
torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) | |
class Text2Video(): | |
def __init__(self,result_dir='./tmp/',gpu_num=1) -> None: | |
model_folder = str(Path(os.path.abspath('')) / 'models') | |
model_filename = 'genrl_stickman_500k_2.pt' | |
if not os.path.isfile(os.path.join(model_folder, model_filename)): | |
self.download_model(model_folder, model_filename) | |
if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')): | |
self.download_internvideo2(model_folder) | |
self.agent = torch.load(os.path.join(model_folder, model_filename)) | |
model_name = 'internvideo2' | |
# Get ViCLIP | |
viclip_global_instance = ViCLIPGlobalInstance(model_name) | |
if not viclip_global_instance._instantiated: | |
print("Instantiating InternVideo2") | |
viclip_global_instance.instantiate() | |
self.clip = viclip_global_instance.viclip | |
self.tokenizer = viclip_global_instance.viclip_tokenizer | |
self.result_dir = result_dir | |
if not os.path.exists(self.result_dir): | |
os.mkdir(self.result_dir) | |
def get_prompt(self, prompt, duration): | |
torch.cuda.empty_cache() | |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) | |
start = time.time() | |
prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt | |
prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str | |
labels_list = [prompt_str] | |
with torch.no_grad(): | |
wm = world_model = self.agent.wm | |
connector = self.agent.wm.connector | |
decoder = world_model.heads['decoder'] | |
n_frames = connector.n_frames | |
# Get text(video) embed | |
text_feat = [] | |
for text in labels_list: | |
with torch.no_grad(): | |
text_feat.append(self.clip.get_txt_feat(text,)) | |
text_feat = torch.stack(text_feat, dim=0).to(self.clip.device) | |
video_embed = text_feat | |
B = video_embed.shape[0] | |
T = 1 | |
# Get actions | |
video_embed = video_embed.repeat(1, duration, 1) | |
with torch.no_grad(): | |
# Imagine | |
prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=True) | |
# Decode | |
prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5 | |
save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15) | |
print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds") | |
return os.path.join(self.result_dir, f"{prompt_str}.mp4") | |
def download_model(self, model_folder, model_filename): | |
REPO_ID = 'mazpie/genrl_models' | |
filename_list = [model_filename] | |
if not os.path.exists(model_folder): | |
os.makedirs(model_folder) | |
for filename in filename_list: | |
local_file = os.path.join(model_folder, filename) | |
if not os.path.exists(local_file): | |
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False) | |
def download_internvideo2(self, model_folder): | |
REPO_ID = 'OpenGVLab/InternVideo2-Stage2_1B-224p-f4' | |
filename_list = ['InternVideo2-stage2_1b-224p-f4.pt'] | |
if not os.path.exists(model_folder): | |
os.makedirs(model_folder) | |
for filename in filename_list: | |
local_file = os.path.join(model_folder, filename) | |
if not os.path.exists(local_file): | |
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False) | |
if __name__ == '__main__': | |
t2v = Text2Video() | |
video_path = t2v.get_prompt('a black swan swims on the pond', 8) | |
print('done', video_path) |