import gradio as gr
import os
import torch
import argparse
import torchvision
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    PNDMScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
    HeunDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DEISMultistepScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer

import os, sys
# Put the parent directory of sys.path[0] on the import path; this has to run
# before the project-local imports (`sample`, `models`) just below.
sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
# import imageio
from torchvision.utils import save_image
import spaces

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
# NOTE: from here on `args` is the OmegaConf config loaded from the --config
# file, not the argparse namespace.
args = OmegaConf.load(args.config)

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

transformer_model = get_models(args).to(device, dtype=torch.float16)
# state_dict = find_model(args.ckpt)
# msg, unexp = transformer_model.load_state_dict(state_dict, strict=False)

# The config flag selects which VAE variant (and checkpoint subfolder) is loaded.
if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_path,
        subfolder="vae_temporal_decoder",
        torch_dtype=torch.float16,
    ).to(device)
else:
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_path,
        subfolder="vae",
        torch_dtype=torch.float16,
    ).to(device)
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    args.pretrained_model_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to(device)

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()

# Maps each supported `sample_method` name to (scheduler class, extra keyword
# arguments for from_pretrained). Only DDIM and DDPM are given clip_sample=False.
_SCHEDULERS = {
    'DDIM': (DDIMScheduler, {'clip_sample': False}),
    'EulerDiscrete': (EulerDiscreteScheduler, {}),
    'DDPM': (DDPMScheduler, {'clip_sample': False}),
    'DPMSolverMultistep': (DPMSolverMultistepScheduler, {}),
    'DPMSolverSinglestep': (DPMSolverSinglestepScheduler, {}),
    'PNDM': (PNDMScheduler, {}),
    'HeunDiscrete': (HeunDiscreteScheduler, {}),
    'EulerAncestralDiscrete': (EulerAncestralDiscreteScheduler, {}),
    'DEISMultistep': (DEISMultistepScheduler, {}),
    'KDPM2AncestralDiscrete': (KDPM2AncestralDiscreteScheduler, {}),
}


@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    """Generate a video from a text prompt and return the path of the written file.

    Args:
        text_input: Prompt text, forwarded as the pipeline's first positional
            argument.
        sample_method: Name of the scheduler to use; must be a key of
            ``_SCHEDULERS``.
        scfg_scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Passed to ``torch.manual_seed`` before sampling.
        height: Forwarded to the pipeline as ``height``.
        width: Forwarded to the pipeline as ``width``.
        video_length: Forwarded to the pipeline as ``video_length``.
        diffusion_step: Forwarded to the pipeline as ``num_inference_steps``.

    Returns:
        The path of the ``.mp4`` file written with ``torchvision.io.write_video``
        at 8 fps.

    Raises:
        ValueError: If ``sample_method`` is not a supported scheduler name.
    """
    torch.manual_seed(seed)

    # Fail with a clear message instead of an UnboundLocalError further down
    # when the requested sampler is unknown.
    try:
        scheduler_cls, extra_kwargs = _SCHEDULERS[sample_method]
    except KeyError:
        raise ValueError(
            f"Unsupported sample_method {sample_method!r}; "
            f"expected one of {sorted(_SCHEDULERS)}"
        ) from None

    # Every scheduler is loaded from the same checkpoint subfolder with the
    # noise-schedule settings taken from the config.
    scheduler = scheduler_cls.from_pretrained(
        args.pretrained_model_path,
        subfolder="scheduler",
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
        **extra_kwargs,
    )

    # A fresh pipeline is assembled per call around the module-level models so
    # that the scheduler chosen for this request is used.
    videogen_pipeline = LattePipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(
        text_input,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=diffusion_step,
        guidance_scale=scfg_scale,
        enable_temporal_attentions=args.enable_temporal_attentions,
        num_images_per_prompt=1,
        mask_feature=True,
        enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
    ).video

    # Plain string concatenation: save_img_path acts as a prefix and no path
    # separator is inserted here.
    # NOTE(review): the file name is fixed, so every call overwrites the same
    # file -- confirm this is acceptable if requests can run concurrently.
    save_path = args.save_img_path + 'temp' + '.mp4'
    torchvision.io.write_video(save_path, videos[0], fps=8)
    return save_path


# exist_ok avoids the check-then-create race of testing os.path.exists first.
os.makedirs(args.save_img_path, exist_ok=True)

intro = """
# project page | paper #
# We will continue update Latte. # """ # ) gr.Markdown("