File size: 4,157 Bytes
153e804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import torch
from omegaconf import OmegaConf

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.saving_utils import npz_to_video_grid
from scripts.sample_text2video import sample_text2video
from scripts.sample_utils import load_model
from lvdm.models.modules.lora import change_lora_v2

from huggingface_hub import hf_hub_download


def save_results(videos, save_dir, 
                 save_name="results", save_fps=8
                 ):
    """Write each sample of ``videos`` to its own ``.mp4`` under ``<save_dir>/videos``.

    Args:
        videos: Batch of generated videos indexed along dim 0; each sample is
            handed to ``npz_to_video_grid`` as a batch of one (``videos[i:i+1, ...]``).
        save_dir: Parent directory; a ``videos`` subdirectory is created if missing.
        save_name: Filename prefix; files are named ``<save_name>_<iii>.mp4``.
        save_fps: Frame rate forwarded to ``npz_to_video_grid``.

    Returns:
        List of written file paths in batch order (empty for an empty batch).
    """
    save_subdir = os.path.join(save_dir, "videos")
    os.makedirs(save_subdir, exist_ok=True)
    # Build the paths once so the files written and the paths returned cannot drift apart.
    video_path_list = [os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
                       for i in range(videos.shape[0])]
    for i, video_path in enumerate(video_path_list):
        npz_to_video_grid(videos[i:i+1, ...], video_path, fps=save_fps)
    # Indexing [0] unconditionally raised IndexError for an empty batch.
    if video_path_list:
        print(f'Successfully saved videos in {video_path_list[0]}')
    return video_path_list
    

class Text2Video():
    """Text-to-video generator built on a VideoCrafter base model with optional LoRA styles.

    On construction, missing checkpoints are downloaded from the Hugging Face Hub
    and the base model is loaded without LoRA. ``get_prompt`` then renders one
    video for a text prompt, optionally with one of the LoRA styles, and returns
    the path of the saved ``.mp4``.
    """

    # Hugging Face Hub repository holding the base checkpoint and all LoRA checkpoints.
    REPO_ID = 'VideoCrafter/t2v-version-1-1'
    CONFIG_FILE = 'models/base_t2v/model_config.yaml'
    BASE_CKPT = 'models/base_t2v/model_rm_wtm.ckpt'
    # Copy of the base checkpoint in shared memory; used instead of BASE_CKPT when present.
    SHM_CKPT = '/dev/shm/model_rm_wtm.ckpt'
    # Index 0 = no LoRA; index i > 0 loads LORA_PATHS[i] and appends
    # LORA_TRIGGER_WORDS[i] to the prompt. The two tuples must stay aligned.
    LORA_PATHS = ('',
                  'models/videolora/lora_001_Loving_Vincent_style.ckpt',
                  'models/videolora/lora_002_frozenmovie_style.ckpt',
                  'models/videolora/lora_003_MakotoShinkaiYourName_style.ckpt',
                  'models/videolora/lora_004_coco_style_v2.ckpt')
    LORA_TRIGGER_WORDS = ('', 'Loving Vincent style', 'frozenmovie style',
                          'MakotoShinkaiYourName style', 'coco style')
    # Upper bound applied to the requested number of DDIM steps.
    MAX_DDIM_STEPS = 60
    # Byte budget for the prompt-derived part of output filenames, leaving room for
    # the "_000.mp4" suffix under the common 255-byte per-name filesystem limit.
    MAX_FILENAME_BYTES = 200

    def __init__(self, result_dir='./tmp/') -> None:
        """Download missing weights, load the base model and create a DDIM sampler.

        Args:
            result_dir: Directory passed to ``save_results`` as the output location.
        """
        self.download_model()
        ckpt_path = self.SHM_CKPT if os.path.exists(self.SHM_CKPT) else self.BASE_CKPT
        config = OmegaConf.load(self.CONFIG_FILE)
        # Instance-level mutable copies, kept for callers that read these attributes.
        self.lora_path_list = list(self.LORA_PATHS)
        self.lora_trigger_word_list = list(self.LORA_TRIGGER_WORDS)
        # LoRA is not injected at load time; get_prompt applies it per request.
        model, _, _ = load_model(config, ckpt_path, gpu_id=0, inject_lora=False)
        self.model = model
        # LoRA path ('' = none) and scale from the previous request; passed back
        # to change_lora_v2 as last_time_lora / last_time_lora_scale.
        self.last_time_lora = ''
        self.last_time_lora_scale = 1.0
        self.result_dir = result_dir
        self.save_fps = 8
        self.ddim_sampler = DDIMSampler(model)
        # Value returned by change_lora_v2 and fed back into it on the next call;
        # presumably the pre-LoRA weights — TODO confirm against change_lora_v2.
        self.origin_weight = None

    @staticmethod
    def _prompt_to_filename(prompt, max_bytes=MAX_FILENAME_BYTES):
        """Turn a prompt into a filesystem-friendly base name.

        '/' becomes '_slash_' and ' ' becomes '_'. The result is then truncated to
        ``max_bytes`` UTF-8 bytes, dropping any character cut in half, so long
        prompts do not exceed the filesystem's per-filename length limit.
        """
        name = prompt.replace("/", "_slash_").replace(" ", "_")
        return name.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")

    def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
        """Generate one video for ``input_text`` and return the path of the saved file.

        Args:
            input_text: Text prompt.
            steps: Requested DDIM steps; values above MAX_DDIM_STEPS are capped.
            model_index: 0 for the base model; i > 0 selects the i-th LoRA style,
                whose trigger word is appended to the prompt.
            eta: Passed to ``sample_text2video`` as ``eta``.
            cfg_scale: Passed to ``sample_text2video`` as ``cfg_scale``.
            lora_scale: Passed to ``change_lora_v2`` as ``lora_scale``.

        Returns:
            Path of the first saved ``.mp4`` (one sample is generated).

        Raises:
            IndexError: If ``model_index`` is outside ``[0, len(self.lora_path_list))``.
        """
        n_styles = len(self.lora_path_list)
        # Negative indices would otherwise wrap around and record a LoRA path as
        # "last applied" even though inject_lora is False for them.
        if not 0 <= model_index < n_styles:
            raise IndexError(f"model_index must be in [0, {n_styles}), got {model_index}")
        torch.cuda.empty_cache()
        steps = min(steps, self.MAX_DDIM_STEPS)
        inject_lora = model_index > 0
        if inject_lora:
            input_text = input_text + ', ' + self.lora_trigger_word_list[model_index]
        lora_path = self.lora_path_list[model_index]
        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=lora_path,
                    last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale, origin_weight=self.origin_weight)
        # Record the LoRA state right after the swap (change_lora_v2 is assumed to
        # modify self.model in place). Updating it only after sampling/saving left
        # the bookkeeping stale whenever those raised (e.g. CUDA OOM), so the next
        # call handed change_lora_v2 the wrong last_time_lora.
        self.last_time_lora = lora_path
        self.last_time_lora_scale = lora_scale

        all_videos = sample_text2video(self.model, input_text, n_samples=1, batch_size=1,
                        sample_type='ddim', sampler=self.ddim_sampler,
                        ddim_steps=steps, eta=eta, 
                        cfg_scale=cfg_scale,
                        )
        save_name = self._prompt_to_filename(input_text, self.MAX_FILENAME_BYTES)
        video_path_list = save_results(all_videos, self.result_dir, save_name=save_name, save_fps=self.save_fps)
        return video_path_list[0]

    def download_model(self):
        """Download the base checkpoint and every LoRA checkpoint not present locally.

        Each file is fetched from REPO_ID with ``local_dir='./'``, so it lands at
        its repository-relative path, which is also the path checked for existence.
        """
        filename_list = [self.BASE_CKPT] + [path for path in self.LORA_PATHS if path]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=self.REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)