"""Text-to-video inference wrapper: base checkpoint plus optional LoRA styles."""
import os
import torch
from omegaconf import OmegaConf
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.saving_utils import npz_to_video_grid
from scripts.sample_text2video import sample_text2video
from scripts.sample_utils import load_model
from lvdm.models.modules.lora import change_lora_v2
from huggingface_hub import hf_hub_download


def save_results(videos, save_dir, save_name="results", save_fps=8):
    """Write every clip of a batch to ``<save_dir>/videos`` as an mp4 file.

    Args:
        videos: Batch indexed along its first axis; ``videos.shape[0]`` is
            used as the number of clips and each ``videos[i:i+1, ...]`` slice
            is handed to ``npz_to_video_grid``.
        save_dir: Parent directory. A ``videos`` sub-directory is created
            inside it when missing.
        save_name: File-name stem; clips are named ``{save_name}_{i:03d}.mp4``.
        save_fps: Frame rate forwarded to ``npz_to_video_grid``.

    Returns:
        List of the written file paths, in batch order (empty for an empty
        batch).
    """
    save_subdir = os.path.join(save_dir, "videos")
    os.makedirs(save_subdir, exist_ok=True)

    # Build each path once and reuse it for both writing and the return value.
    video_path_list = [
        os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
        for i in range(videos.shape[0])
    ]
    for i, video_path in enumerate(video_path_list):
        npz_to_video_grid(videos[i:i + 1, ...], video_path, fps=save_fps)

    # An empty batch has no first path to report; skip the log line instead
    # of failing with IndexError.
    if video_path_list:
        print(f'Successfully saved videos in {video_path_list[0]}')
    return video_path_list


class Text2Video():
    """Loads the base text-to-video model once and samples videos on demand.

    A LoRA style can be selected per call through ``model_index``; index 0
    means "no LoRA" and the remaining indices follow ``lora_path_list``.
    """

    # Single source of truth for the files used by both __init__ and
    # download_model, so the two cannot drift apart.
    _REPO_ID = 'VideoCrafter/t2v-version-1-1'
    _CONFIG_FILE = 'models/base_t2v/model_config.yaml'
    _BASE_CKPT = 'models/base_t2v/model_rm_wtm.ckpt'
    # Copy of the base checkpoint that is preferred when it exists.
    _SHM_CKPT = '/dev/shm/model_rm_wtm.ckpt'
    _LORA_CKPTS = (
        'models/videolora/lora_001_Loving_Vincent_style.ckpt',
        'models/videolora/lora_002_frozenmovie_style.ckpt',
        'models/videolora/lora_003_MakotoShinkaiYourName_style.ckpt',
        'models/videolora/lora_004_coco_style_v2.ckpt',
    )
    # Upper bound applied to the requested number of DDIM steps.
    _MAX_STEPS = 60

    def __init__(self, result_dir='./tmp/') -> None:
        """Download missing checkpoints, load the model and build the sampler.

        Args:
            result_dir: Directory passed to ``save_results`` for every
                generated video.
        """
        self.download_model()

        ckpt_path = self._BASE_CKPT
        if os.path.exists(self._SHM_CKPT):
            ckpt_path = self._SHM_CKPT
        config = OmegaConf.load(self._CONFIG_FILE)

        # Index 0 is the "no LoRA" entry; both lists are indexed by model_index.
        self.lora_path_list = [''] + list(self._LORA_CKPTS)
        self.lora_trigger_word_list = ['', 'Loving Vincent style', 'frozenmovie style',
                                       'MakotoShinkaiYourName style', 'coco style']

        model, _, _ = load_model(config, ckpt_path, gpu_id=0, inject_lora=False)
        self.model = model
        # LoRA path/scale handed to change_lora_v2 as the "previous" state.
        self.last_time_lora = ''
        self.last_time_lora_scale = 1.0
        self.result_dir = result_dir
        self.save_fps = 8
        self.ddim_sampler = DDIMSampler(model)
        # Value returned by change_lora_v2, passed back in on the next call.
        self.origin_weight = None

    def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
        """Sample one video for ``input_text`` and return the saved file path.

        Args:
            input_text: Text prompt. For ``model_index > 0`` the matching
                trigger word is appended as ``", <trigger word>"``.
            steps: Requested DDIM steps; values above 60 are reduced to 60.
            model_index: Index into ``lora_path_list`` /
                ``lora_trigger_word_list``; 0 selects no LoRA.
            eta: Forwarded to ``sample_text2video``.
            cfg_scale: Forwarded to ``sample_text2video``.
            lora_scale: Forwarded to ``change_lora_v2``.

        Returns:
            Path of the first (and only requested) saved video.

        Raises:
            IndexError: If ``model_index`` is outside the configured lists.
        """
        torch.cuda.empty_cache()
        steps = min(steps, self._MAX_STEPS)

        # NOTE(review): a negative model_index skips injection but still picks
        # a LoRA path via negative indexing -- confirm callers never pass one.
        inject_lora = model_index > 0
        if inject_lora:
            input_text = input_text + ', ' + self.lora_trigger_word_list[model_index]
        lora_path = self.lora_path_list[model_index]

        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora,
                                            lora_scale=lora_scale,
                                            lora_path=lora_path,
                                            last_time_lora=self.last_time_lora,
                                            last_time_lora_scale=self.last_time_lora_scale,
                                            origin_weight=self.origin_weight)
        # Record the new LoRA state as soon as change_lora_v2 has returned.
        # Doing it only after sampling would leave stale bookkeeping behind
        # whenever sampling raises (e.g. out of memory), and the next call
        # would then describe the wrong "previous" LoRA to change_lora_v2.
        self.last_time_lora = lora_path
        self.last_time_lora_scale = lora_scale

        all_videos = sample_text2video(self.model, input_text,
                                       n_samples=1, batch_size=1,
                                       sample_type='ddim', sampler=self.ddim_sampler,
                                       ddim_steps=steps, eta=eta,
                                       cfg_scale=cfg_scale,
                                       )

        # Turn the prompt into a file-name stem; str.replace is already a
        # no-op when the character is absent, so no membership test is needed.
        prompt_str = input_text.replace("/", "_slash_").replace(" ", "_")

        video_path_list = save_results(all_videos, self.result_dir,
                                       save_name=prompt_str, save_fps=self.save_fps)
        return video_path_list[0]

    def download_model(self):
        """Fetch the base and LoRA checkpoints that are not present locally.

        Files are looked up relative to the current working directory and
        downloaded into it under the same relative path.
        """
        for filename in (self._BASE_CKPT, *self._LORA_CKPTS):
            if not os.path.exists(filename):
                hf_hub_download(repo_id=self._REPO_ID, filename=filename,
                                local_dir='./', local_dir_use_symlinks=False)