"""Gradio web demo for Latte text-to-video / text-to-image sampling.

At start-up the script loads the YAML config given by ``--config``, the Latte
transformer, the VAE and the T5 text encoder. It then serves a Gradio UI whose
"Run" button calls :func:`gen_video` and displays the ``.mp4`` it writes.
"""
import argparse
import os
import sys

import gradio as gr
import spaces
import torch
import torchvision
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from omegaconf import OmegaConf
from torchvision.utils import save_image  # noqa: F401  (currently unused)
from transformers import T5EncoderModel, T5Tokenizer

# Append the parent of the script's directory to sys.path so the project-local
# ``sample`` and ``models`` packages below can be imported.
sys.path.append(os.path.split(sys.path[0])[0])

from sample.pipeline_latte import LattePipeline  # noqa: E402
from models import get_models  # noqa: E402

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
# From here on ``args`` is the OmegaConf config loaded from the YAML file,
# not the argparse namespace.
args = OmegaConf.load(args.config)

# Inference only: disable autograd for the whole process.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): no explicit checkpoint load happens here; this assumes
# get_models() returns a transformer with its weights already loaded — confirm.
transformer_model = get_models(args).to(device, dtype=torch.float16)

if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16
    ).to(device)
else:
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16
    ).to(device)
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16
).to(device)

# Put all networks in eval mode for inference.
transformer_model.eval()
vae.eval()
text_encoder.eval()

# Sampler name -> diffusers scheduler class. The UI exposes only a subset.
SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "EulerDiscrete": EulerDiscreteScheduler,
    "DDPM": DDPMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
    "PNDM": PNDMScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler,
}
# Schedulers that are additionally constructed with ``clip_sample=False``.
_NO_CLIP_SAMPLE = frozenset({"DDIM", "DDPM"})


def _load_scheduler(sample_method):
    """Build the scheduler named *sample_method* from the pretrained ``scheduler`` config.

    ``beta_start``, ``beta_end``, ``beta_schedule`` and ``variance_type`` are
    overridden with the values from the loaded YAML config.

    Raises:
        ValueError: if *sample_method* is not a key of ``SCHEDULERS``.
    """
    try:
        scheduler_cls = SCHEDULERS[sample_method]
    except KeyError:
        # Fail with an actionable message instead of leaving ``scheduler`` unbound.
        raise ValueError(
            f"Unsupported sample method {sample_method!r}; "
            f"expected one of {sorted(SCHEDULERS)}"
        ) from None
    kwargs = {
        "subfolder": "scheduler",
        "beta_start": args.beta_start,
        "beta_end": args.beta_end,
        "beta_schedule": args.beta_schedule,
        "variance_type": args.variance_type,
    }
    if sample_method in _NO_CLIP_SAMPLE:
        kwargs["clip_sample"] = False
    return scheduler_cls.from_pretrained(args.pretrained_model_path, **kwargs)


@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    """Generate a clip for *text_input* and return the path of the written ``.mp4``.

    Args:
        text_input: prompt text passed to the pipeline.
        sample_method: a key of ``SCHEDULERS``.
        scfg_scale: passed through as the pipeline's ``guidance_scale``.
        seed: seed for ``torch.manual_seed``.
        height, width: requested frame height and width passed to the pipeline.
        video_length: number of frames (the UI label says 1 for T2I, 16 for T2V).
        diffusion_step: passed as ``num_inference_steps``.

    Returns:
        Path of the video written into ``args.save_img_path``.

    Raises:
        ValueError: if *sample_method* is not supported.
    """
    torch.manual_seed(int(seed))
    scheduler = _load_scheduler(sample_method)
    videogen_pipeline = LattePipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device)
    # Gradio numeric widgets can deliver floats; coerce sizes, frame count and
    # step count to int before handing them to the pipeline.
    videos = videogen_pipeline(
        text_input,
        video_length=int(video_length),
        height=int(height),
        width=int(width),
        num_inference_steps=int(diffusion_step),
        guidance_scale=scfg_scale,
        enable_temporal_attentions=args.enable_temporal_attentions,
        num_images_per_prompt=1,
        mask_feature=True,
        enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
    ).video
    # os.path.join places the file inside save_img_path whether or not the
    # configured path ends with a separator (plain concatenation did not).
    # NOTE(review): the fixed filename is shared by every request, so
    # overlapping requests would overwrite each other's output.
    save_path = os.path.join(args.save_img_path, "temp.mp4")
    torchvision.io.write_video(save_path, videos[0], fps=8)
    return save_path


os.makedirs(args.save_img_path, exist_ok=True)

# Not rendered by the current layout; kept as module-level text for reuse.
intro = "\n\nLatte: Latent Diffusion Transformer for Video Generation\n\n"

with gr.Blocks() as demo:
    gr.Markdown("\nLatte: Latent Diffusion Transformer for Video Generation\n")
    gr.Markdown(
        "\n\nLatte supports both T2I and T2V, and will be continuously updated, so stay tuned!\n\n"
    )
    # NOTE(review): these labels carry no link targets here; restore the URLs
    # (e.g. Markdown ``[Arxiv Report](url)``) if they were lost.
    gr.Markdown("\n[Arxiv Report] | [Project Page] | [Github]\n")

    with gr.Row():
        # Left column: prompt and sampling controls.
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
            with gr.Row():
                with gr.Column(scale=0.5):
                    sample_method = gr.Dropdown(
                        choices=["DDIM", "EulerDiscrete", "PNDM"],
                        label="Sample Method",
                        value="DDIM",
                    )
                with gr.Column(scale=0.5):
                    video_length = gr.Dropdown(
                        choices=[1, 16],
                        label="Video Length (1 for T2I and 16 for T2V)",
                        value=16,
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            with gr.Row():
                # Height/width sliders are non-interactive, so generation
                # always uses the 512 default.
                with gr.Column(scale=0.5):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )
        # Right column: generated video and the Run button.
        with gr.Column(scale=0.6, visible=True) as video_upload:
            # elem_id is Chinese for "output video".
            output = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")
            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button("💭Run")

    run.click(
        gen_video,
        [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step],
        [output],
    )

# For self-hosting instead of a share link, pass server_name/server_port to launch().
demo.launch(debug=False, share=True)