Spaces:
Runtime error
Runtime error
import torch | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
import os | |
import sys | |
try: | |
import utils | |
from diffusion import create_diffusion | |
except: | |
sys.path.append(os.path.split(sys.path[0])[0]) | |
import utils | |
from diffusion import create_diffusion | |
import argparse | |
import torchvision | |
from PIL import Image | |
from einops import rearrange | |
from models import get_models | |
from diffusers.models import AutoencoderKL | |
from models.clip import TextEmbedder | |
from omegaconf import OmegaConf | |
from pytorch_lightning import seed_everything | |
from utils import mask_generation_before | |
from diffusers.utils.import_utils import is_xformers_available | |
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor | |
from vlogger.videofusion import fusion | |
from vlogger.videocaption import captioning | |
from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos | |
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model | |
from vlogger.planning_utils.gpt4_utils import (readscript, | |
readtimescript, | |
readprotagonistscript, | |
readreferencescript, | |
readzhscript) | |
def auto_inpainting(args,
                    video_input,
                    masked_video,
                    mask,
                    prompt,
                    image,
                    vae,
                    text_encoder,
                    image_encoder,
                    diffusion,
                    model,
                    device,
                    ):
    """Sample one clip by inpainting the masked part of ``video_input``.

    Args:
        args: sampling config; reads ``ref_cfg_scale``, ``use_fp16``,
            ``negative_prompt`` and ``cfg_scale``.
        video_input: tensor unpacked as ``(b, f, c, h, w)``; only its shape is
            used here. NOTE(review): the noise latent uses batch 1, so ``b`` is
            expected to be 1.
        masked_video: pixel-space clip with the same ``(b, f, c, h, w)`` layout;
            it is VAE-encoded and passed to the sampler as ``x_start``.
        mask: mask in the ``(b, f, c, h, w)`` layout; its channel 0 is resized
            to the latent resolution.
        prompt: text prompt; ``None`` is treated as ``""``.
        image: optional reference image for the CLIP image encoder. When given,
            the model's IP scale is set to ``args.ref_cfg_scale``.
        vae, text_encoder, image_encoder, diffusion, model: sampling components.
        device: device for the noise latent and the CLIP pixel values.

    Returns:
        The VAE-decoded conditional sample, frame-first (one entry per frame).
    """
    # Scaling applied to VAE latents (the usual Stable Diffusion VAE factor).
    latent_scale = 0.18215

    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        # Zero embedding serves as the unconditional image branch for CFG.
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        model = ip_scale_set(model, args.ref_cfg_scale)
        if args.use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)

    b, f, c, h, w = video_input.shape
    latent_h = h // 8
    latent_w = w // 8

    # The noise latent must have the same frame count as the encoded
    # masked_video/mask it is sampled alongside, so use f rather than a
    # fixed 16.
    if args.use_fp16:
        z = torch.randn(1, 4, f, latent_h, latent_w, dtype=torch.float16, device=device)  # b,c,f,h,w
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, f, latent_h, latent_w, device=device)  # b,c,f,h,w

    # Encode the masked clip frame by frame, then restore the b,c,f,h,w layout.
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(latent_scale)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Duplicate everything for the conditional / unconditional (negative) halves.
    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt] + [args.negative_prompt]
    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        cfg_scale=args.cfg_scale,
                        use_fp16=args.use_fp16,
                        ip_hidden_states=image_prompt_embeds)

    # Sample images:
    samples = diffusion.ddim_sample_loop(model.forward_with_cfg,
                                         z.shape,
                                         z,
                                         clip_denoised=False,
                                         model_kwargs=model_kwargs,
                                         progress=True,
                                         device=device,
                                         mask=mask,
                                         x_start=masked_video,
                                         use_concat=True,
                                         )
    samples, _ = samples.chunk(2, dim=0)  # keep the conditional half
    if args.use_fp16:
        samples = samples.to(dtype=torch.float16)

    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()  # c,f,h,w -> f,c,h,w
    video_clip = vae.decode(video_clip / latent_scale).sample
    return video_clip
def _load_model(args, device):
    """Build the video model, apply the TCA/IP transforms and load EMA weights.

    Reads ``args.enable_xformers_memory_efficient_attention``, ``args.ckpt``
    and ``args.use_compile``. Checkpoint entries whose key is not present in
    the model's state dict are ignored.

    Raises:
        ValueError: xformers attention was requested but is not available.
    """
    model = get_models(args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    state_dict = torch.load(args.ckpt, map_location=lambda storage, loc: storage)['ema']
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)
    model.eval()  # important!

    # Compile only after the weights are loaded. A compiled module's state-dict
    # keys carry an ``_orig_mod.`` prefix, so filtering the checkpoint against
    # them would drop every weight and leave the model at its initial values.
    if args.use_compile:
        model = torch.compile(model)
    return model


def _load_reference_image(args, ref_index, num_characters):
    """Return the 256x256 reference image for a scene, or ``None``.

    ``ref_index`` is 1-based into ``args.reference_image_path``. A value of 0,
    or one greater than ``num_characters``, means the scene has no reference.
    """
    if ref_index == 0 or ref_index > num_characters:
        return None
    # Image.resize returns a new image rather than resizing in place, so its
    # result has to be kept. The ``with`` closes the underlying file.
    with Image.open(args.reference_image_path[ref_index - 1]) as image:
        return image.resize((256, 256))


def _save_video(samples, save_dir, index, fps):
    """Write frame-first ``samples`` to ``<save_dir>/<index>.mp4``.

    Pixel values are mapped with ``x * 0.5 + 0.5`` before scaling to 0-255,
    i.e. they are presumably in [-1, 1].
    """
    os.makedirs(save_dir, exist_ok=True)
    video = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8)
    video = video.cpu().permute(0, 2, 3, 1).contiguous()  # f,c,h,w -> f,h,w,c
    torchvision.io.write_video(os.path.join(save_dir, f"{index}.mp4"), video, fps=fps)


def _post_process(args):
    """Run fusion, captioning and audio steps over the per-scene videos."""
    fusion(args.save_origin_video_path)
    captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
    fusion(args.save_caption_video_path)
    make_audio(args.script_file_path, args.save_audio_path)
    merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
    concatenate_videos(args.save_audio_caption_video_path)
    print('final video save path {}'.format(args.save_audio_caption_video_path))


def main(args):
    """Generate a vlog from the script files referenced by the config ``args``.

    Loads the models, then reads the protagonist, script, reference, zh and
    time scripts. Each scene is generated clip by clip with ``auto_inpainting``,
    each clip continuing from the tail of the previous one. When
    ``args.video_transition`` is set, a transition clip into the previous scene
    is prepended. Each scene is written to ``args.save_origin_video_path`` as
    ``<i>.mp4``, and post-processing runs last.
    """
    # Setup PyTorch: seed_everything also seeds torch's RNG.
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(args.seed)

    model = _load_model(args, device)
    diffusion = create_diffusion(str(args.num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
    text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    if args.use_fp16:
        print('Warnning: using half percision for inferencing!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
    print("model ready!\n", flush=True)

    # load protagonist script
    character_places = readprotagonistscript(args.protagonist_file_path)
    print("protagonists ready!", flush=True)
    # load script
    video_list = readscript(args.script_file_path)
    print("video script ready!", flush=True)
    # load reference script
    reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
    print("reference script ready!", flush=True)
    # load zh script. The parsed result is not used here; captioning() below
    # receives the same path. The call only reads the file up front.
    readzhscript(args.zh_script_file_path)
    print("zh script ready!", flush=True)
    # load time script
    num_characters = len(character_places)
    time_list = readtimescript(args.time_file_path)
    print("time script ready!", flush=True)

    def inpaint(video_input, mask_type, prompt, pil_image):
        """Mask ``video_input`` with ``mask_type`` and sample the masked frames."""
        mask = mask_generation_before(mask_type, video_input.shape, video_input.dtype, device)  # b,f,c,h,w
        masked_video = video_input * (mask == 0)
        return auto_inpainting(args, video_input, masked_video, mask, prompt, pil_image,
                               vae, text_encoder, image_encoder, diffusion, model, device)

    # generation begin
    sample_list = []
    for i, text_prompt in enumerate(video_list):
        clips = []
        sample_list.append(clips)
        # The reference image depends only on the scene, so load it once.
        pil_image = _load_reference_image(args, reference_lists[i][0], num_characters)

        for clip_idx in range(time_list[i]):
            if clip_idx == 0:
                print('Generating the ({}) prompt'.format(text_prompt), flush=True)
                # First clip: all-zero frames inpainted with the "first0" mask.
                video_input = torch.zeros(
                    [1, args.num_frames, 3, args.image_size[0], args.image_size[1]], device=device)
                clips.append(inpaint(video_input, "first0", text_prompt, pil_image))
                continue

            # Stop once the scene covers its requested duration (seconds).
            if sum(clip.shape[0] for clip in clips) / args.fps >= time_list[i]:
                break
            print('Generating the ({}) prompt'.format(text_prompt), flush=True)
            # Continue from the last ``researve_frame`` frames of the previous
            # clip; the remaining frames are zeros to be inpainted.
            pre_video = clips[-1][-args.researve_frame:]
            _, c, h, w = pre_video.shape
            lat_video = torch.zeros(args.num_frames - args.researve_frame, c, h, w, device=device)
            video_input = torch.concat([pre_video, lat_video], dim=0).to(device).unsqueeze(0)
            video_clip = inpaint(video_input, args.mask_type, text_prompt, pil_image)
            clips.append(video_clip[args.researve_frame:])
            print(video_clip[args.researve_frame:].shape)

        # transition: bridge the previous scene's last frame and this scene's
        # first frame, and prepend the inpainted in-between frames.
        if args.video_transition and i != 0:
            video_1 = sample_list[i - 1][-1][-1:]
            video_2 = clips[0][:1]
            _, c, h, w = video_1.shape
            video_middle = torch.zeros(args.num_frames - 2, c, h, w, device=device)
            video_input = torch.concat([video_1, video_middle, video_2], dim=0).to(device).unsqueeze(0)
            video_clip = inpaint(video_input, "onelast1",
                                 "smooth transition, slow motion, slow changing.", pil_image)
            clips.insert(0, video_clip[1:-1])

        # save videos, trimmed to the scene's duration in frames
        samples = torch.concat(clips, dim=0)[0: time_list[i] * args.fps]
        _save_video(samples, args.save_origin_video_path, i, args.fps)

    # post processing
    _post_process(args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
    args = parser.parse_args()
    omega_conf = OmegaConf.load(args.config)

    save_path = omega_conf.save_path
    # Caption / audio outputs go next to save_path (its parent directory).
    parent_dir = save_path.rsplit('/', 1)[0]
    base_paths = {
        "save_origin_video_path": os.path.join(save_path, "origin_video"),
        "save_caption_video_path": os.path.join(parent_dir, "caption_video"),
        "save_audio_path": os.path.join(parent_dir, "audio"),
        "save_audio_caption_video_path": os.path.join(parent_dir, "audio_caption_video"),
    }

    # .get so that a config without a sample_num key runs once.
    sample_num = omega_conf.get("sample_num")
    if sample_num is None:
        for key, path in base_paths.items():
            omega_conf[key] = path
        main(omega_conf)
    else:
        # Run i uses seed base_seed + i and output paths suffixed with -i,
        # so every run is reproducible from the configured seed alone.
        base_seed = omega_conf.seed
        for i in range(sample_num):
            for key, path in base_paths.items():
                omega_conf[key] = path + f'-{i}'
            omega_conf.seed = base_seed + i
            main(omega_conf)