# Vlogger-ShowMaker / sample_scripts / vlog_read_script_sample.py
# Uploaded by GrayShine ("Upload 60 files", commit 2e5e07d, verified); original size 14.4 kB.
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import os
import sys
try:
import utils
from diffusion import create_diffusion
except:
sys.path.append(os.path.split(sys.path[0])[0])
import utils
from diffusion import create_diffusion
import argparse
import torchvision
from PIL import Image
from einops import rearrange
from models import get_models
from diffusers.models import AutoencoderKL
from models.clip import TextEmbedder
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from utils import mask_generation_before
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from vlogger.videofusion import fusion
from vlogger.videocaption import captioning
from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
from vlogger.planning_utils.gpt4_utils import (readscript,
readtimescript,
readprotagonistscript,
readreferencescript,
readzhscript)
def auto_inpainting(args,
                    video_input,
                    masked_video,
                    mask,
                    prompt,
                    image,
                    vae,
                    text_encoder,
                    image_encoder,
                    diffusion,
                    model,
                    device,
                    ):
    """Fill the masked frames of a clip with the video diffusion model.

    Runs classifier-free-guided DDIM sampling in VAE latent space. Sampling is
    conditioned on the text ``prompt`` (guided against
    ``args.negative_prompt``) and on the latents of ``masked_video``. When a
    reference ``image`` is given, it is also conditioned on that image through
    ``ip_hidden_states``.

    Args:
        args: Config object; reads ``ref_cfg_scale``, ``use_fp16``,
            ``negative_prompt`` and ``cfg_scale``.
        video_input: 5-D tensor unpacked as (b, f, c, h, w); only its shape is
            used here.
        masked_video: Pixel-space clip laid out like ``video_input`` with the
            frames to generate zeroed by the caller.
        mask: Mask tensor; channel 0 of dim 2 is taken and resized to the
            latent grid (assumes the (b, f, c, h, w) layout of ``video_input``
            -- TODO confirm against ``mask_generation_before``).
        prompt: Text prompt; ``None`` is treated as ``""``.
        image: Optional PIL reference image; ``None`` disables image prompting.
        vae, text_encoder, image_encoder, diffusion, model: Pipeline parts.
        device: Device on which noise and image embeddings are created.

    Returns:
        Decoded frames of the first (prompt-conditioned) half of the CFG batch,
        laid out as (f, channels, height, width).
    """
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        # Pair the reference-image embedding with an all-zero "unconditional"
        # embedding so that the image branch has the same two-way batch as the
        # [prompt, negative_prompt] text pair below.
        clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        # NOTE(review): the IP scale is only set when an image is given; with
        # image=None it keeps whatever an earlier call set -- confirm intended.
        model = ip_scale_set(model, args.ref_cfg_scale)
        if args.use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)

    b, f, c, h, w = video_input.shape
    latent_h = h // 8
    latent_w = w // 8

    # Noise gets as many frames as the input clip. The frame count used to be
    # hard-coded to 16, which broke any config whose clips are not 16 frames.
    noise_dtype = torch.float16 if args.use_fp16 else None  # None -> torch default dtype
    z = torch.randn(1, 4, f, latent_h, latent_w, dtype=noise_dtype, device=device)  # b,c,f,h,w
    if args.use_fp16:
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)

    # Encode the known frames frame-by-frame into scaled VAE latents; the same
    # 0.18215 factor is divided out again before decoding below.
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    # Take channel 0, resize it to the latent grid and add a singleton channel
    # axis so that it lines up with the b,c,f,h,w latents.
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Classifier-free guidance: duplicate every input; the first half goes with
    # the prompt, the second half with the negative prompt.
    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt] + [args.negative_prompt]
    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        cfg_scale=args.cfg_scale,
                        use_fp16=args.use_fp16,
                        ip_hidden_states=image_prompt_embeds)
    # Sample latents:
    samples = diffusion.ddim_sample_loop(model.forward_with_cfg,
                                         z.shape,
                                         z,
                                         clip_denoised=False,
                                         model_kwargs=model_kwargs,
                                         progress=True,
                                         device=device,
                                         mask=mask,
                                         x_start=masked_video,
                                         use_concat=True,
                                         )
    samples, _ = samples.chunk(2, dim=0)  # keep the prompt-conditioned half
    if args.use_fp16:
        samples = samples.to(dtype=torch.float16)

    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()  # (f, 4, latent_h, latent_w)
    video_clip = vae.decode(video_clip / 0.18215).sample  # e.g. [16, 3, 256, 256]
    return video_clip
def _load_pipeline(args, device):
    """Build the diffusion model, the sampler and the encoders used for generation.

    Args:
        args: Config; reads ``enable_xformers_memory_efficient_attention``,
            ``ckpt``, ``use_compile``, ``num_sampling_steps``,
            ``pretrained_model_path``, ``image_encoder_path`` and
            ``use_fp16`` (plus whatever ``get_models`` reads).
        device: Device the modules are moved to.

    Returns:
        Tuple ``(model, diffusion, vae, text_encoder, image_encoder)``.

    Raises:
        ValueError: If xformers attention is requested but not installed.
    """
    model = get_models(args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        model.enable_xformers_memory_efficient_attention()

    # Copy only the checkpoint's 'ema' entries whose names exist in the model.
    # This must run BEFORE torch.compile: the compiled wrapper prefixes every
    # state-dict key with "_orig_mod.". In the old order, use_compile therefore
    # matched no keys and silently kept the initial weights.
    state_dict = torch.load(args.ckpt, map_location=lambda storage, loc: storage)['ema']
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    num_missing = len(model_dict) - len(pretrained_dict)
    if num_missing:
        print('Warning: {} model entries were not found in {} and keep their initial values.'.format(
            num_missing, args.ckpt), flush=True)
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()  # important!
    if args.use_compile:
        model = torch.compile(model)

    diffusion = create_diffusion(str(args.num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
    text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    if args.use_fp16:
        print('Warnning: using half percision for inferencing!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
    return model, diffusion, vae, text_encoder, image_encoder


def _load_reference_image(args, reference_index, num_characters):
    """Return the 256x256 protagonist reference image for a script line, or None.

    ``reference_index`` is 1-based into ``args.reference_image_path``. A value
    of 0, or an index beyond the known characters, means "no reference image".
    """
    if reference_index == 0 or reference_index > num_characters:
        return None
    with Image.open(args.reference_image_path[reference_index - 1]) as img:
        # Image.resize returns a new image. The old code dropped that result,
        # so the resize never took effect. The copy returned here no longer
        # depends on the file being open.
        return img.resize((256, 256))


def _inpaint_clip(args, video_input, mask_type, prompt, pil_image, pipeline_parts, device):
    """Mask ``video_input`` with ``mask_type`` and let the model fill the masked frames."""
    model, diffusion, vae, text_encoder, image_encoder = pipeline_parts
    mask = mask_generation_before(mask_type, video_input.shape, video_input.dtype, device)  # b,f,c,h,w
    masked_video = video_input * (mask == 0)
    return auto_inpainting(args,
                           video_input,
                           masked_video,
                           mask,
                           prompt,
                           pil_image,
                           vae,
                           text_encoder,
                           image_encoder,
                           diffusion,
                           model,
                           device,
                           )


def _generate_line_clips(args, text_prompt, duration, pil_image, pipeline_parts, device):
    """Generate clips covering at least ``duration`` seconds for one script line.

    The first clip is sampled with the "first0" mask. Each further clip is
    conditioned on the last ``args.researve_frame`` frames generated so far
    and contributes ``args.num_frames - args.researve_frame`` new frames.
    Generation stops once ``duration * args.fps`` frames exist. Unlike the old
    fixed ``range(duration)`` loop, this still reaches the target when each
    extension adds less than one second of footage.

    Returns:
        List of frame tensors in playback order. The first entry is the whole
        first clip; each later entry holds only that step's new frames.

    Raises:
        ValueError: If an extension is needed but ``args.researve_frame`` is
            not in ``(0, args.num_frames)``.
    """
    target_frames = duration * args.fps
    print('Generating the ({}) prompt'.format(text_prompt), flush=True)
    video_input = torch.zeros([1, args.num_frames, 3, args.image_size[0], args.image_size[1]]).to(device)
    clips = [_inpaint_clip(args, video_input, "first0", text_prompt, pil_image, pipeline_parts, device)]
    context = clips[0][-args.researve_frame:]
    while sum(clip.shape[0] for clip in clips) < target_frames:
        num_new = args.num_frames - args.researve_frame
        if args.researve_frame <= 0 or num_new <= 0:
            raise ValueError('researve_frame must be in (0, num_frames) to extend a clip, got '
                             'researve_frame={} and num_frames={}'.format(args.researve_frame, args.num_frames))
        print('Generating the ({}) prompt'.format(text_prompt), flush=True)
        _, c, h, w = context.shape
        lat_video = torch.zeros(num_new, c, h, w).to(device)
        video_input = torch.concat([context, lat_video], dim=0).to(device).unsqueeze(0)
        video_clip = _inpaint_clip(args, video_input, args.mask_type, text_prompt, pil_image,
                                   pipeline_parts, device)
        new_frames = video_clip[args.researve_frame:]
        clips.append(new_frames)
        print(new_frames.shape)
        # Roll the context window over the real footage, so that it always
        # holds researve_frame frames even when a step yields fewer new frames.
        context = torch.concat([context, new_frames], dim=0)[-args.researve_frame:]
    return clips


def _make_transition(args, prev_clips, next_clips, pil_image, pipeline_parts, device):
    """Generate the in-between frames bridging the previous line's last frame and this line's first frame."""
    video_1 = prev_clips[-1][-1:]
    video_2 = next_clips[0][:1]
    _, c, h, w = video_1.shape
    video_middle = torch.zeros(args.num_frames - 2, c, h, w).to(device)
    video_input = torch.concat([video_1, video_middle, video_2], dim=0).to(device).unsqueeze(0)
    video_clip = _inpaint_clip(args, video_input, "onelast1",
                               "smooth transition, slow motion, slow changing.",
                               pil_image, pipeline_parts, device)
    return video_clip[1:-1]  # drop the two anchor frames


def _save_video(samples, path, fps):
    """Map frames from [-1, 1] to uint8 and write them to ``path`` as a video."""
    video_ = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8)
    video_ = video_.cpu().permute(0, 2, 3, 1).contiguous()  # (frames, h, w, channels)
    torchvision.io.write_video(path, video_, fps=fps)


def main(args):
    """Generate a captioned, narrated vlog from the script files named in ``args``.

    Steps:
        1. Load the model stack.
        2. Read the protagonist, shot, reference, zh and time scripts.
        3. Render one ``<i>.mp4`` per script line into
           ``args.save_origin_video_path``, optionally prefixed by a generated
           transition from the previous line.
        4. Run fusion, captioning, audio generation/merging and
           concatenation.
    """
    # Setup PyTorch:
    if args.seed:
        torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(args.seed)
    pipeline_parts = _load_pipeline(args, device)
    print("model ready!\n", flush=True)

    # load protagonist script
    character_places = readprotagonistscript(args.protagonist_file_path)
    print("protagonists ready!", flush=True)
    # load script
    video_list = readscript(args.script_file_path)
    print("video script ready!", flush=True)
    # load reference script
    reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
    print("reference script ready!", flush=True)
    # load zh script (the result is not used here; captioning() below
    # receives the file path itself)
    readzhscript(args.zh_script_file_path)
    print("zh script ready!", flush=True)
    # load time script
    num_characters = len(character_places)
    time_list = readtimescript(args.time_file_path)
    print("time script ready!", flush=True)

    # generation begin
    os.makedirs(args.save_origin_video_path, exist_ok=True)
    prev_clips = None  # only the previous line is needed, for the transition
    for i, text_prompt in enumerate(video_list):
        # Load the reference image once per line, not once per clip.
        pil_image = _load_reference_image(args, reference_lists[i][0], num_characters)
        clips = _generate_line_clips(args, text_prompt, time_list[i], pil_image, pipeline_parts, device)
        # transition
        if args.video_transition and i != 0:
            clips.insert(0, _make_transition(args, prev_clips, clips, pil_image, pipeline_parts, device))
        # save videos (trimmed to the scripted duration, transition included)
        samples = torch.concat(clips, dim=0)[0: time_list[i] * args.fps]
        _save_video(samples, os.path.join(args.save_origin_video_path, f"{i}.mp4"), args.fps)
        prev_clips = clips

    # post processing
    fusion(args.save_origin_video_path)
    captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
    fusion(args.save_caption_video_path)
    make_audio(args.script_file_path, args.save_audio_path)
    merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
    concatenate_videos(args.save_audio_caption_video_path)
    print('final video save path {}'.format(args.save_audio_caption_video_path))
def build_output_paths(save_path, suffix=""):
    """Return the four output directories derived from ``save_path``.

    The raw clips go inside ``save_path``. The caption, audio and final
    videos go under ``save_path.rsplit('/', 1)[0]``, i.e. next to it; a path
    without '/' is used as-is. ``suffix`` is appended to every directory,
    e.g. ``"-0"`` for the first of several samples.

    Returns:
        Dict keyed by the config field names that ``main`` reads.
    """
    parent = save_path.rsplit('/', 1)[0]
    return {
        "save_origin_video_path": os.path.join(save_path, "origin_video") + suffix,
        "save_caption_video_path": os.path.join(parent, "caption_video") + suffix,
        "save_audio_path": os.path.join(parent, "audio") + suffix,
        "save_audio_caption_video_path": os.path.join(parent, "audio_caption_video") + suffix,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
    args = parser.parse_args()
    omega_conf = OmegaConf.load(args.config)
    save_path = omega_conf.save_path
    # .get() tolerates configs without a sample_num entry; attribute access
    # raises on a missing key in recent OmegaConf releases.
    sample_num = omega_conf.get("sample_num")
    if sample_num is not None:
        base_seed = omega_conf.seed
        for i in range(sample_num):
            for key, value in build_output_paths(save_path, f'-{i}').items():
                omega_conf[key] = value
            # Run i uses base_seed + i. The old `seed += i` accumulated
            # (offsets 0, 1, 3, 6, ...) and crashed on a None seed.
            omega_conf.seed = None if base_seed is None else base_seed + i
            main(omega_conf)
    else:
        for key, value in build_output_paths(save_path).items():
            omega_conf[key] = value
        main(omega_conf)