Spaces:
Sleeping
Sleeping
import gc | |
import os | |
import torch | |
import spaces | |
import gradio as gr | |
from diffusers import LattePipeline | |
from transformers import T5EncoderModel, BitsAndBytesConfig | |
import imageio | |
from torchvision.utils import save_image | |
def flush():
    """Run Python garbage collection, then release cached CUDA allocator memory.

    Collecting first lets objects that are only reachable through reference
    cycles be freed before the allocator's cache is emptied.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def bytes_to_giga_bytes(bytes):
    """Return a byte count expressed in gibibytes (units of 1024**3 bytes).

    The parameter name shadows the ``bytes`` builtin; it is kept so existing
    keyword-argument callers keep working.
    """
    scaled = bytes
    for _ in range(3):  # bytes -> KiB -> MiB -> GiB
        scaled = scaled / 1024
    return scaled
def initialize_pipeline(model_id: str = "maxin-cn/Latte-1"):
    """Load a prompt-encoding-only Latte pipeline backed by a 4-bit T5 encoder.

    The pipeline is built with ``transformer=None``, so it can encode prompts
    but cannot denoise. The caller is expected to free it and load a separate
    generation pipeline, as ``generate_video`` does.

    Args:
        model_id: Hugging Face Hub repo id of the Latte checkpoint. Defaults to
            the checkpoint this app was written for.

    Returns:
        ``(pipe, text_encoder)``: the transformer-less pipeline and the
        quantized text encoder it wraps. The encoder is also returned on its
        own so the caller can drop every reference to it when freeing memory.
    """
    # 4-bit weights with fp16 compute keep the T5 encoder's footprint small.
    text_encoder = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
        device_map="auto",
    )
    # transformer=None skips loading the denoiser; only prompt encoding is needed here.
    pipe = LattePipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        transformer=None,
        device_map="balanced",
    )
    return pipe, text_encoder
def _save_video(videos, fps=8, quality=5):
    """Write the first video in ``videos`` to a new, uniquely named .mp4 file.

    Args:
        videos: CPU tensor of generated frames with float values expected in
            [0, 1]. It is indexed as ``videos[0]`` and that slice is permuted
            with ``(0, 2, 3, 1)``, so it is presumably (batch, frames,
            channels, height, width) — TODO confirm against the pipeline's
            output type.
        fps: Frame rate of the written video.
        quality: imageio/ffmpeg quality setting (0 lowest, 10 highest).

    Returns:
        Path of the written .mp4 file. The caller owns the file.
    """
    import tempfile

    frames = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8)
    # Use a unique path per request. A fixed name in the working directory lets
    # concurrent requests overwrite, and serve, each other's output.
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # only the name is needed; imageio reopens the file itself
    imageio.mimwrite(
        path,
        # Hand imageio a NumPy array rather than relying on it to coerce a tensor.
        frames[0].permute(0, 2, 3, 1).numpy(),
        fps=fps,
        quality=quality,
    )
    return path


# NOTE(review): the file imports `spaces` but never uses it. If this Space runs
# on ZeroGPU hardware, this function presumably needs `@spaces.GPU` — confirm.
def generate_video(
    prompt: str,
    negative_prompt: str = "",
    video_length: int = 16,
    num_inference_steps: int = 50,
    progress=gr.Progress(),
):
    """Generate a short video for ``prompt`` with Latte-1 and return its file path.

    Runs in two stages to limit peak GPU memory:
    1. A transformer-less pipeline with a 4-bit text encoder computes the
       prompt embeddings and is then freed.
    2. A text-encoder-less fp16 pipeline on CUDA denoises from those
       embeddings.

    Once a stage's pipeline is loaded, it is released even if a later step
    in that stage raises.

    Args:
        prompt: Text describing the video to generate.
        negative_prompt: Text describing content to steer away from.
        video_length: Number of frames to generate. Coerced to ``int`` because
            UI sliders may deliver floats.
        num_inference_steps: Number of denoising steps. Coerced to ``int``.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        Path to the generated .mp4 file.
    """
    # Seed the global torch RNG with a fixed value for reproducibility.
    torch.manual_seed(0)

    # Stage 1: compute prompt embeddings.
    progress(0, desc="Initializing pipeline...")
    pipe, text_encoder = initialize_pipeline()
    try:
        progress(0.2, desc="Encoding prompt...")
        with torch.no_grad():
            prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                prompt,
                negative_prompt=negative_prompt,
            )
    finally:
        # Drop every reference to the encoder before loading the denoiser so
        # its memory can be reclaimed, even if encoding failed.
        progress(0.3, desc="Cleaning up...")
        del text_encoder
        del pipe
        flush()

    # Stage 2: denoise from the precomputed embeddings.
    progress(0.4, desc="Initializing generation pipeline...")
    pipe = LattePipeline.from_pretrained(
        "maxin-cn/Latte-1",
        text_encoder=None,  # prompts are already encoded
        torch_dtype=torch.float16,
    ).to("cuda")
    try:
        progress(0.5, desc="Generating video...")
        videos = pipe(
            video_length=int(video_length),
            num_inference_steps=int(num_inference_steps),
            negative_prompt=None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type="pt",
        ).frames.cpu()

        progress(0.8, desc="Post-processing...")
        output_path = _save_video(videos, fps=8, quality=5)
    finally:
        progress(0.9, desc="Cleaning up...")
        del pipe
        flush()

    progress(1.0, desc="Done!")
    return output_path
def create_demo():
    """Build the Gradio Blocks UI that wires the form controls to ``generate_video``.

    Returns:
        The configured, not-yet-launched ``gr.Blocks`` app.
    """
    header = """
# Latte Video Generation
Generate short videos using the Latte-1 model.
"""
    with gr.Blocks() as demo:
        gr.Markdown(header)
        with gr.Row():
            # Left column: generation controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value="a cat wearing sunglasses and working as a lifeguard at pool.",
                    info="Describe what you want to generate",
                )
                negative_box = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    info="What you don't want to see in the generation",
                )
                length_slider = gr.Slider(
                    minimum=8, maximum=32, step=8, value=16,
                    label="Video Length (frames)",
                )
                steps_slider = gr.Slider(
                    minimum=20, maximum=100, step=10, value=50,
                    label="Number of Inference Steps",
                )
                run_button = gr.Button("Generate Video")
            # Right column: the result.
            with gr.Column():
                video_out = gr.Video(label="Generated Video")

        # Input order must match generate_video's positional parameters.
        form_inputs = [prompt_box, negative_box, length_slider, steps_slider]
        run_button.click(
            fn=generate_video,
            inputs=form_inputs,
            outputs=video_out,
        )
    return demo
if __name__ == "__main__":
    # Build the UI, route events through Gradio's queue, and serve without a public share link.
    demo = create_demo().queue()  # Blocks.queue() returns the same Blocks instance
    demo.launch(share=False)