File size: 5,123 Bytes
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
ba39165
ae407cb
 
 
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83208c
2d9a728
 
 
 
 
 
f83208c
2d9a728
 
 
 
 
 
 
e6f6d44
 
 
f83208c
2d9a728
 
a8f4e91
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cb3f2f
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae407cb
2d9a728
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from pathlib import Path 
import os
import sys
sys.path.append(str(Path(os.path.abspath(''))))

import torch
import numpy as np
from tools.genrl_utils import ViCLIPGlobalInstance

import time
import torchvision
from huggingface_hub import hf_hub_download

import spaces
# Hugging Face access token, passed to hf_hub_download when fetching the
# InternVideo2 checkpoint. Read with .get() so that importing this module does
# not raise KeyError when HF_TOKEN is unset; with token=None, huggingface_hub
# falls back to its default token resolution (e.g. a cached `huggingface-cli login`).
hf_token = os.environ.get('HF_TOKEN')

def save_videos(batch_tensors, savedir, filenames, fps=10):
    """Write each batch entry as an H.264 ``.mp4`` with its samples tiled side by side.

    Args:
        batch_tensors: 6-D tensor that the code below indexes as
            ``[b, samples, t, c, h, w]`` (each entry is permuted from
            ``[samples, t, ...]`` to ``[t, samples, ...]``). Values are
            clamped to [0, 1] before conversion to 8-bit.
        savedir: output directory; created if it does not exist.
        filenames: base names (no extension), one per batch entry;
            ``filenames[i]`` names entry ``i``.
        fps: frame rate of the written videos.
    """
    n_samples = batch_tensors.shape[1]
    os.makedirs(savedir, exist_ok=True)
    for idx, vid_tensor in enumerate(batch_tensors):
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), 0., 1.)
        video = video.permute(1, 0, 2, 3, 4)  # [samples, t, c, h, w] -> [t, samples, c, h, w]
        # One grid per frame with all samples in a single row -> [c, H, n*W] (plus padding)
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=n_samples) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)  # [t, c, H, W]
        # Round to nearest (as torchvision.utils.save_image does) rather than
        # truncating, then go channels-last: write_video expects [T, H, W, C] uint8.
        grid = (grid * 255 + 0.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})

class Text2Video():
    """Text-to-video demo wrapper around a pretrained GenRL agent and InternVideo2.

    On construction it downloads any missing checkpoints into ``./models``
    (relative to the current working directory), loads the agent and the
    InternVideo2 encoder, and moves both to CUDA. :meth:`get_prompt` turns a
    text prompt into an ``.mp4`` file written under ``result_dir``.
    """

    # Upper bound (UTF-8 bytes) on the base filename derived from a prompt, so
    # that "<name>.mp4" stays under the common 255-byte filename limit.
    _MAX_FILENAME_BYTES = 200

    def __init__(self, result_dir='./tmp/', gpu_num=1) -> None:
        """Load models (downloading checkpoints if needed) and prepare ``result_dir``.

        Args:
            result_dir: directory where generated videos are written; created
                (including parents) if missing.
            gpu_num: accepted for interface compatibility; not used.
        """
        model_folder = str(Path(os.path.abspath('')) / 'models')
        model_filename = 'genrl_stickman_500k_2.pt'

        if not os.path.isfile(os.path.join(model_folder, model_filename)):
            self.download_model(model_folder, model_filename)
        if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
            self.download_internvideo2(model_folder)
        # NOTE(review): this unpickles a whole agent object. torch>=2.6 defaults
        # torch.load to weights_only=True, which would reject it; if loading
        # fails there, pass weights_only=False (only for trusted checkpoints).
        self.agent = torch.load(os.path.join(model_folder, model_filename), map_location='cpu')
        model_name = 'internvideo2'

        # Get ViCLIP (InternVideo2); skip instantiation if the global instance already did it
        viclip_global_instance = ViCLIPGlobalInstance(model_name)
        if not viclip_global_instance._instantiated:
            print("Instantiating InternVideo2")
            viclip_global_instance.instantiate(device='cpu')
        self.clip = viclip_global_instance.viclip
        self.tokenizer = viclip_global_instance.viclip_tokenizer

        self.result_dir = result_dir
        # makedirs so nested paths work and an already-existing directory is not an error
        os.makedirs(self.result_dir, exist_ok=True)

        self.agent.to('cuda')
        self.clip.to('cuda')

    @spaces.GPU
    def get_prompt(self, prompt, duration):
        """Generate a video for ``prompt`` and return the path of the saved ``.mp4``.

        Args:
            prompt: text description of the desired video; encoded verbatim
                by the InternVideo2 text encoder.
            duration: number of times the text embedding is repeated along
                dim 1 before imagination (presumably the number of generated
                steps — TODO confirm against ``video_imagine``).

        Returns:
            ``os.path.join(result_dir, <filename-safe prompt> + ".mp4")``.
        """
        torch.cuda.empty_cache()

        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()

        # Filesystem-safe name used only for the output file, never for the model input
        prompt_str = self._prompt_to_filename(prompt)

        with torch.no_grad():
            world_model = self.agent.wm
            decoder = world_model.heads['decoder']

            # Text embedding of the raw prompt, stacked into a batch of one
            text_feat = torch.stack([self.clip.get_txt_feat(prompt)], dim=0).to('cuda')

            video_embed = text_feat.repeat(1, duration, 1)
            # Imagine latent states conditioned on the text embedding
            prior = world_model.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=True)
            # Decode to observations; +0.5 presumably shifts a zero-centred mean into [0, 1]
            prior_recon = decoder(world_model.decoder_input_fn(prior))['observation'].mean + 0.5

        save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
        print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    @classmethod
    def _prompt_to_filename(cls, prompt):
        """Return a filesystem-safe base filename derived from ``prompt``.

        '/' becomes '_slash_' and ' ' becomes '_'. The result is cut to at most
        ``_MAX_FILENAME_BYTES`` UTF-8 bytes, dropping any multi-byte character
        split by the cut, so very long prompts do not exceed filename limits.
        """
        name = prompt.replace("/", "_slash_").replace(" ", "_")
        return name.encode('utf-8')[:cls._MAX_FILENAME_BYTES].decode('utf-8', 'ignore')

    def download_model(self, model_folder, model_filename):
        """Download GenRL checkpoint ``model_filename`` from ``mazpie/genrl_models`` into ``model_folder`` if absent."""
        self._download_missing('mazpie/genrl_models', [model_filename], model_folder)

    def download_internvideo2(self, model_folder):
        """Download the InternVideo2 stage-2 checkpoint into ``model_folder`` if absent, authenticating with ``hf_token``."""
        self._download_missing(
            'OpenGVLab/InternVideo2-Stage2_1B-224p-f4',
            ['InternVideo2-stage2_1b-224p-f4.pt'],
            model_folder,
            token=hf_token,
        )

    @staticmethod
    def _download_missing(repo_id, filenames, model_folder, token=None):
        """Fetch each of ``filenames`` from Hub repo ``repo_id`` into ``model_folder``, skipping files already present."""
        os.makedirs(model_folder, exist_ok=True)
        for filename in filenames:
            if not os.path.exists(os.path.join(model_folder, filename)):
                hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_folder, local_dir_use_symlinks=False, token=token)

if __name__ == '__main__':
    # Smoke run: build the pipeline, render one demo prompt, report the output path.
    demo_prompt, demo_duration = 'a black swan swims on the pond', 8
    generator = Text2Video()
    output_path = generator.get_prompt(demo_prompt, demo_duration)
    print('done', output_path)