import gradio as gr
import os
import sys
import torch
import argparse
import torchvision
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer

sys.path.append(os.path.split(sys.path[0])[0])  # make the repo root importable
from sample.pipeline_latte import LattePipeline
from models import get_models
# import imageio
from torchvision.utils import save_image
import spaces
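
# Only the config path is parsed from the command line; every other option below
# (model paths, scheduler betas, output directory, ...) comes from the YAML config.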
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the Latte transformer in fp16 on the available device
transformer_model = get_models(args).to(device, dtype=torch.float16)
# state_dict = find_model(args.ckpt)
# msg, unexp = transformer_model.load_state_dict(state_dict, strict=False)

if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
else:
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)

tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()
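
# gen_video builds the scheduler chosen in the UI, assembles the Latte pipeline,
# samples a clip, and writes it to an MP4 that the Gradio Video component plays.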
@spaces.GPU  # assumed ZeroGPU decorator; `spaces` is imported above but otherwise unused
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    torch.manual_seed(seed)  # fix the seed so results are reproducible
    if sample_method == 'DDIM':
        scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule,
                                                  variance_type=args.variance_type,
                                                  clip_sample=False)
    elif sample_method == 'EulerDiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path,
                                                           subfolder="scheduler",
                                                           beta_start=args.beta_start,
                                                           beta_end=args.beta_end,
                                                           beta_schedule=args.beta_schedule,
                                                           variance_type=args.variance_type)
    elif sample_method == 'DDPM':
        scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule,
                                                  variance_type=args.variance_type,
                                                  clip_sample=False)
    elif sample_method == 'DPMSolverMultistep':
        scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path,
                                                                subfolder="scheduler",
                                                                beta_start=args.beta_start,
                                                                beta_end=args.beta_end,
                                                                beta_schedule=args.beta_schedule,
                                                                variance_type=args.variance_type)
    elif sample_method == 'DPMSolverSinglestep':
        scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path,
                                                                 subfolder="scheduler",
                                                                 beta_start=args.beta_start,
                                                                 beta_end=args.beta_end,
                                                                 beta_schedule=args.beta_schedule,
                                                                 variance_type=args.variance_type)
    elif sample_method == 'PNDM':
        scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule,
                                                  variance_type=args.variance_type)
    elif sample_method == 'HeunDiscrete':
        scheduler = HeunDiscreteScheduler.from_pretrained(args.pretrained_model_path,
                                                          subfolder="scheduler",
                                                          beta_start=args.beta_start,
                                                          beta_end=args.beta_end,
                                                          beta_schedule=args.beta_schedule,
                                                          variance_type=args.variance_type)
    elif sample_method == 'EulerAncestralDiscrete':
        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
                                                                    subfolder="scheduler",
                                                                    beta_start=args.beta_start,
                                                                    beta_end=args.beta_end,
                                                                    beta_schedule=args.beta_schedule,
                                                                    variance_type=args.variance_type)
    elif sample_method == 'DEISMultistep':
        scheduler = DEISMultistepScheduler.from_pretrained(args.pretrained_model_path,
                                                           subfolder="scheduler",
                                                           beta_start=args.beta_start,
                                                           beta_end=args.beta_end,
                                                           beta_schedule=args.beta_schedule,
                                                           variance_type=args.variance_type)
    elif sample_method == 'KDPM2AncestralDiscrete':
        scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
                                                                    subfolder="scheduler",
                                                                    beta_start=args.beta_start,
                                                                    beta_end=args.beta_end,
                                                                    beta_schedule=args.beta_schedule,
                                                                    variance_type=args.variance_type)
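    # rebuild the pipeline on every request so the freshly selected scheduler is used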
    videogen_pipeline = LattePipeline(vae=vae,
                                      text_encoder=text_encoder,
                                      tokenizer=tokenizer,
                                      scheduler=scheduler,
                                      transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(text_input,
                               video_length=video_length,
                               height=height,
                               width=width,
                               num_inference_steps=diffusion_step,
                               guidance_scale=scfg_scale,
                               enable_temporal_attentions=args.enable_temporal_attentions,
                               num_images_per_prompt=1,
                               mask_feature=True,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
                               ).video

    # write the sampled frames to an MP4 that Gradio can serve
    save_path = os.path.join(args.save_img_path, 'temp.mp4')
    torchvision.io.write_video(save_path, videos[0], fps=8)
    return save_path
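

# create the output directory from the config before the first request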
if not os.path.exists(args.save_img_path):
    os.makedirs(args.save_img_path)

intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
</div>
"""
with gr.Blocks() as demo:
    # gr.HTML(intro)
    # with gr.Accordion("README", open=False):
    #     gr.HTML(
    #         """
    #         <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    #         <a href="https://maxin-cn.github.io/latte_project/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2401.03048" target="_blank">paper</a>
    #         </p>
    #         We will continue to update Latte.
    #         """
    #     )
    gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        <h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
        """
    )
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        [<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
        """
    )
    with gr.Row():
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    # text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
            # with gr.Row():
            #     with gr.Column(scale=0.5):
            #         image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
            #     with gr.Column(scale=0.5):
            #         preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
            with gr.Row():
                with gr.Column(scale=0.5):
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
                # with gr.Row():
                #     with gr.Column(scale=1.0):
                #         video_length = gr.Slider(
                #             minimum=1,
                #             maximum=24,
                #             value=1,
                #             step=1,
                #             interactive=True,
                #             label="Video Length (1 for T2I and 16 for T2V)",
                #         )
                with gr.Column(scale=0.5):
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            with gr.Row():
                with gr.Column(scale=0.5):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                # with gr.Row():
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )
        with gr.Column(scale=0.6, visible=True) as video_upload:
            # with gr.Column(visible=True) as video_upload:
            output = gr.Video(interactive=False, include_audio=True, elem_id="output_video")  # .style(height=360)
            # with gr.Column(elem_id="image", scale=0.5) as img_part:
            #     with gr.Tab("Video", elem_id='video_tab'):
            #     with gr.Tab("Image", elem_id='image_tab'):
            #         up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
            # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # clear = gr.Button("Restart")
            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button("💭Run")
                # with gr.Column(scale=0.5, min_width=0):
                #     clear = gr.Button("🔄Clear️")

    run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])

demo.launch(debug=False, share=True)
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
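
# Example local launch (assuming this file is saved as app.py; the config path is the argparse default above):
#   python app.py --config ./configs/t2x/t2v_sample.yaml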