File size: 5,110 Bytes
b3f324b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import argparse
import sys
import os
import random
import imageio
import torch
from diffusers import PNDMScheduler
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from datetime import datetime
from typing import List, Union
import gradio as gr
import numpy as np
from gradio.components import Textbox, Video, Image
from transformers import T5Tokenizer, T5EncoderModel
from opensora.models.ae import ae_stride_config, getae, getae_wrapper
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.sample.pipeline_videogen import VideoGenPipeline
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
@torch.inference_mode()
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
    """Generate a video (or a single frame) for ``prompt`` and save it as an MP4.

    Relies on the module-level globals ``args``, ``transformer_model`` and
    ``videogen_pipeline``, which are created in the ``__main__`` section.

    Args:
        prompt: Text prompt forwarded to the video generation pipeline.
        sample_steps: Number of inference (denoising) steps. Cast to ``int``
            because Gradio sliders deliver their value as a float.
        scale: Guidance scale forwarded as ``guidance_scale``.
        seed: Seed to use; passed through ``randomize_seed_fn`` together
            with ``randomize_seed``.
        randomize_seed: Forwarded to ``randomize_seed_fn``; presumably makes
            it return a fresh random seed instead of ``seed`` (TODO confirm).
        force_images: When True, generate a single frame
            (``video_length=1``) and disable temporal attention.

    Returns:
        Tuple ``(video_path, prompt, model_info, seed)``: the path of the
        written MP4, the prompt echoed back, a human-readable summary of the
        generation settings, and the seed actually used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)  # presumably seeds the RNGs so results are reproducible — TODO confirm
    sample_steps = int(sample_steps)

    # args.version has the form "<frames>x<height>x<width>", e.g. "65x512x512".
    version_parts = args.version.split('x')
    height, width = int(version_parts[1]), int(version_parts[2])
    video_length = 1 if force_images else transformer_model.config.video_length
    num_frames = 1 if video_length == 1 else int(version_parts[0])

    videos = videogen_pipeline(
        prompt,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=sample_steps,
        guidance_scale=scale,
        enable_temporal_attentions=not force_images,
        num_images_per_prompt=1,
        mask_feature=True,
    ).video
    # Hand unused cached GPU memory back between requests.
    torch.cuda.empty_cache()
    # Single prompt with num_images_per_prompt=1, so keep only the first result.
    videos = videos[0]

    # NOTE(review): fixed output path; concurrent requests would overwrite each
    # other's file if Gradio ever runs this handler in parallel.
    tmp_save_path = 'tmp.mp4'
    imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9)  # highest quality is 10, lowest is 0

    display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
    return tmp_save_path, prompt, display_model_info, seed
if __name__ == '__main__':
    # Hard-coded demo configuration. argparse.Namespace provides the same
    # attribute access (args.ae, args.version, ...) as an ad-hoc class would.
    args = argparse.Namespace(
        ae='CausalVAEModel_4x8x8',
        force_images=False,
        model_path='LanguageBind/Open-Sora-Plan-v1.0.0',
        text_encoder_name='DeepFloyd/t5-v1_1-xxl',
        version='65x512x512',  # "<frames>x<height>x<width>"
    )
    device = torch.device('cuda:0')

    # Load models in float16 on the GPU.
    transformer_model = LatteT2V.from_pretrained(
        args.model_path, subfolder=args.version, torch_dtype=torch.float16, cache_dir='cache_dir'
    ).to(device)
    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(
        device, dtype=torch.float16
    )
    vae.vae.enable_tiling()

    # Latent spatial size = frame size divided by the autoencoder's spatial strides.
    # Assumes ae_stride_config entries are ordered (time, height, width) — TODO confirm.
    version_parts = args.version.split('x')
    height, width = int(version_parts[1]), int(version_parts[2])
    ae_stride = ae_stride_config[args.ae]
    latent_size = (height // ae_stride[1], width // ae_stride[2])
    vae.latent_size = latent_size
    transformer_model.force_images = args.force_images

    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
    text_encoder = T5EncoderModel.from_pretrained(
        args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16
    ).to(device)

    # Inference only: switch every module to eval mode.
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = PNDMScheduler()
    videogen_pipeline = VideoGenPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device=device)

    # Input order matches generate_img's parameters:
    # prompt, sample_steps, scale, seed, randomize_seed, force_images.
    demo = gr.Interface(
        fn=generate_img,
        inputs=[
            Textbox(label="", placeholder="Please enter your prompt. \n"),
            # NOTE(review): with minimum=1 and step=10 the grid is 1, 11, 21, ...,
            # so the default value of 50 is off-grid.
            gr.Slider(label='Sample Steps', minimum=1, maximum=500, value=50, step=10),
            gr.Slider(label='Guidance Scale', minimum=0.1, maximum=30.0, value=10.0, step=0.1),
            gr.Slider(label="Seed", minimum=0, maximum=203279, step=1, value=0),
            gr.Checkbox(label="Randomize seed", value=True),
            gr.Checkbox(label="Generate image (1 frame video)", value=False),
        ],
        # Output order matches generate_img's return tuple.
        outputs=[
            Video(label="Vid", width=512, height=512),
            Textbox(label="input prompt"),
            Textbox(label="model info"),
            gr.Slider(label='seed'),
        ],
        title=title_markdown,
        description=DESCRIPTION,
        theme=gr.themes.Default(),
        css=block_css,
        examples=examples,
    )
    demo.launch()