# Latte-1 / demo.py
# Gradio demo for Latte (Latent Diffusion Transformer for Video Generation).
# Uploaded via huggingface_hub (revision 94bafa8).
import gradio as gr
import os
import torch
import argparse
import torchvision
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
EulerDiscreteScheduler, DPMSolverMultistepScheduler,
HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
# import imageio
from torchvision.utils import save_image
import spaces
# --- CLI / config ------------------------------------------------------------
# The only real CLI flag is --config; the parsed namespace is then *replaced*
# by the OmegaConf object loaded from that YAML, so every `args.<key>` below
# refers to a config entry, not a command-line argument.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

# Inference-only demo: disable autograd globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Latte transformer backbone, run in fp16 on the selected device.
transformer_model = get_models(args).to(device, dtype=torch.float16)
# state_dict = find_model(args.ckpt)
# msg, unexp = transformer_model.load_state_dict(state_dict, strict=False)

# VAE choice is config-driven: either the temporal-decoder variant or the
# standard AutoencoderKL, both loaded from subfolders of the pretrained repo.
if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
else:
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)

# T5 tokenizer + encoder provide the text conditioning for the pipeline.
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()
@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    """Generate a video (or a single image) from a text prompt with Latte.

    Args:
        text_input: text prompt to condition generation on.
        sample_method: scheduler name; one of the keys of ``scheduler_classes``.
        scfg_scale: classifier-free guidance scale.
        seed: RNG seed for reproducible sampling.
        height, width: output resolution in pixels.
        video_length: number of frames (1 => T2I, 16 => T2V).
        diffusion_step: number of denoising steps.

    Returns:
        Path of the saved .mp4 file.

    Raises:
        ValueError: if ``sample_method`` is not a recognized scheduler name.
    """
    torch.manual_seed(seed)

    # One dispatch table replaces the previous ten near-identical if/elif
    # branches. It also fixes a latent crash: the old ladder had no `else`,
    # so an unknown sample_method died with `NameError: scheduler` instead
    # of a meaningful error.
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    try:
        scheduler_cls = scheduler_classes[sample_method]
    except KeyError:
        raise ValueError(f"Unknown sample method: {sample_method!r}") from None

    scheduler_kwargs = dict(subfolder="scheduler",
                            beta_start=args.beta_start,
                            beta_end=args.beta_end,
                            beta_schedule=args.beta_schedule,
                            variance_type=args.variance_type)
    # Only the DDIM/DDPM branches disabled sample clipping in the original ladder.
    if sample_method in ('DDIM', 'DDPM'):
        scheduler_kwargs['clip_sample'] = False
    scheduler = scheduler_cls.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)

    # Assemble the pipeline from the module-level models; only the scheduler
    # varies per request.
    videogen_pipeline = LattePipeline(vae=vae,
                                      text_encoder=text_encoder,
                                      tokenizer=tokenizer,
                                      scheduler=scheduler,
                                      transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(text_input,
                               video_length=video_length,
                               height=height,
                               width=width,
                               num_inference_steps=diffusion_step,
                               guidance_scale=scfg_scale,
                               enable_temporal_attentions=args.enable_temporal_attentions,
                               num_images_per_prompt=1,
                               mask_feature=True,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
                               ).video

    # os.path.join fixes the case where save_img_path lacks a trailing slash:
    # plain concatenation produced e.g. "./resultstemp.mp4", outside the
    # directory created at startup.
    save_path = os.path.join(args.save_img_path, 'temp.mp4')
    torchvision.io.write_video(save_path, videos[0], fps=8)
    return save_path
# Ensure the output directory for generated videos exists. exist_ok=True is
# both idiomatic and race-free, unlike the previous exists()-then-makedirs()
# pair, which could raise FileExistsError if the directory appeared in between.
os.makedirs(args.save_img_path, exist_ok=True)
# Page-header HTML. NOTE(review): currently unused — the `gr.HTML(intro)`
# call in the Blocks section is commented out; kept for a future re-enable.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
</div>
"""
# --- Gradio UI ---------------------------------------------------------------
# Layout: prompt and sampling controls in the left column, generated video and
# the Run button in the right column.
with gr.Blocks() as demo:
    # gr.HTML(intro)
    # with gr.Accordion("README", open=False):
    #     gr.HTML(
    #         """
    #         <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    #         <a href="https://maxin-cn.github.io/latte_project/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2401.03048" target="_blank">paper</a>
    #         </p>
    #         We will continue update Latte.
    #         """
    #     )
    gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        <h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
        """
    )
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        [<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
        """
    )
    with gr.Row():
        # Left column: all user-adjustable generation inputs.
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    # text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
            # with gr.Row():
            #     with gr.Column(scale=0.5):
            #         image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
            #     with gr.Column(scale=0.5):
            #         preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
            with gr.Row():
                with gr.Column(scale=0.5):
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
                # with gr.Row():
                #     with gr.Column(scale=1.0):
                #         video_length = gr.Slider(
                #             minimum=1,
                #             maximum=24,
                #             value=1,
                #             step=1,
                #             interactive=True,
                #             label="Video Length (1 for T2I and 16 for T2V)",
                #         )
                with gr.Column(scale=0.5):
                    # 1 frame => text-to-image; 16 frames => text-to-video.
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidence Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            with gr.Row():
                with gr.Column(scale=0.5):
                    # Resolution sliders are intentionally non-interactive
                    # (interactive=False): the demo generates at fixed 512x512.
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                # with gr.Row():
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )
        # Right column: rendered output video plus the Run button.
        with gr.Column(scale=0.6, visible=True) as video_upload:
            # with gr.Column(visible=True) as video_upload:
            output = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")  # .style(height=360)
            # with gr.Column(elem_id="image", scale=0.5) as img_part:
            #     with gr.Tab("Video", elem_id='video_tab'):
            #     with gr.Tab("Image", elem_id='image_tab'):
            #         up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
            # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # clear = gr.Button("Restart")
            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button("💭Run")
                # with gr.Column(scale=0.5, min_width=0):
                #     clear = gr.Button("🔄Clear️")

    # Wire the Run button: the input list maps positionally onto
    # gen_video's parameters; the returned file path feeds the Video output.
    run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])
# share=True requests a public gradio.live tunnel (a no-op on HF Spaces,
# which serve the app directly).
demo.launch(debug=False, share=True)
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)