Spaces:
Runtime error
Runtime error
| import gc | |
| import os | |
| import torch | |
| import spaces | |
| import gradio as gr | |
| from diffusers import LattePipeline | |
| from transformers import T5EncoderModel, BitsAndBytesConfig | |
| import imageio | |
| from torchvision.utils import save_image | |
def flush():
    """Run the Python garbage collector, then release cached CUDA memory.

    Collection runs first so that tensors held only by unreachable objects are
    freed before the CUDA caching allocator is asked to drop its cached blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to binary gigabytes (1 GiB = 1024**3 bytes).

    The parameter name shadows the builtin ``bytes``; it is kept for
    compatibility with keyword callers.
    """
    value = bytes
    # Divide step by step: bytes -> KiB -> MiB -> GiB.
    for _ in range(3):
        value = value / 1024
    return value
def initialize_pipeline():
    """Load the prompt-encoding half of Latte-1.

    Returns:
        A ``(pipe, text_encoder)`` pair.
        - ``text_encoder`` is the T5 encoder, loaded with a bitsandbytes 4-bit
          quantization config (fp16 compute dtype) and ``device_map="auto"``.
        - ``pipe`` is a ``LattePipeline`` built around that encoder with
          ``transformer=None``. In this file it is only used for
          ``encode_prompt``.

    Both objects are returned so the caller can drop every reference to them
    and reclaim their memory afterwards.
    """
    repo_id = "maxin-cn/Latte-1"
    four_bit_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    encoder = T5EncoderModel.from_pretrained(
        repo_id,
        subfolder="text_encoder",
        quantization_config=four_bit_config,
        device_map="auto",
    )
    encode_only_pipe = LattePipeline.from_pretrained(
        repo_id,
        text_encoder=encoder,
        transformer=None,
        device_map="balanced",
    )
    return encode_only_pipe, encoder
# NOTE(review): `spaces` is imported at the top of the file but nothing was
# decorated with it; on a ZeroGPU Space a GPU-bound function must carry
# `@spaces.GPU` for CUDA to be available. Two model loads plus up to 100
# denoising steps may exceed the default time budget -- consider
# `@spaces.GPU(duration=...)` if calls time out.
@spaces.GPU
def generate_video(
    prompt: str,
    negative_prompt: str = "",
    video_length: int = 16,
    num_inference_steps: int = 50,
    progress=gr.Progress(),
):
    """Generate a short video for *prompt* with Latte-1 and return its MP4 path.

    The work runs in two phases so that the quantized text encoder and the
    fp16 generation pipeline are never loaded at the same time:

    1. Load the encoder-only pipeline (``initialize_pipeline``), compute the
       prompt and negative-prompt embeddings, then free it.
    2. Load the pipeline with ``text_encoder=None`` in fp16 on CUDA, denoise
       from the precomputed embeddings, write the frames to an MP4 file, then
       free it.

    GPU objects from each phase are released even if that phase raises.

    Args:
        prompt: Text description of the video to generate.
        negative_prompt: Text describing content to steer away from.
        video_length: Number of frames to generate (coerced to ``int``).
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        progress: Gradio progress tracker; Gradio injects it.

    Returns:
        Path to a newly created, uniquely named ``.mp4`` file (8 fps) in the
        system temp directory. The file is not deleted here.
    """
    import tempfile  # local import keeps the file's top-level import block unchanged

    # Gradio sliders may deliver numbers as floats (e.g. 16.0); these values
    # are used as a frame count and a step count, so coerce them to int.
    video_length = int(video_length)
    num_inference_steps = int(num_inference_steps)

    # Fixed global seed for reproducible outputs.
    torch.manual_seed(0)

    # Phase 1: encode prompts with the quantized text encoder.
    progress(0, desc="Initializing pipeline...")
    pipe, text_encoder = initialize_pipeline()
    try:
        progress(0.2, desc="Encoding prompt...")
        with torch.no_grad():
            prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
                prompt,
                negative_prompt=negative_prompt,
            )
    finally:
        # Drop the encoder before loading the generation pipeline, and also
        # on failure, so its memory is not leaked into the next request.
        progress(0.3, desc="Cleaning up...")
        del text_encoder
        del pipe
        flush()

    # Phase 2: denoise from the precomputed embeddings (no text encoder).
    progress(0.4, desc="Initializing generation pipeline...")
    pipe = LattePipeline.from_pretrained(
        "maxin-cn/Latte-1",
        text_encoder=None,
        torch_dtype=torch.float16,
    ).to("cuda")
    try:
        progress(0.5, desc="Generating video...")
        videos = pipe(
            video_length=video_length,
            num_inference_steps=num_inference_steps,
            negative_prompt=None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type="pt",
        ).frames.cpu()

        progress(0.8, desc="Post-processing...")
        # Presumably float frames in [0, 1]; clamp and scale to uint8 pixels.
        videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8)

        # Give each request its own file so concurrent or queued calls cannot
        # overwrite one another's output (previously a fixed "temp_output.mp4").
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)  # imageio reopens the path itself
        # Assumes videos[0] is (frames, C, H, W) -- TODO confirm; the permute
        # produces the channel-last (frames, H, W, C) layout that imageio expects.
        imageio.mimwrite(
            output_path,
            videos[0].permute(0, 2, 3, 1).numpy(),
            fps=8,
            quality=5,
        )
        progress(0.9, desc="Cleaning up...")
    finally:
        del pipe
        flush()

    progress(1.0, desc="Done!")
    return output_path
def create_demo():
    """Build (but do not launch) the Gradio UI for Latte-1 video generation.

    Layout:
        - A Markdown header.
        - A row with two columns.
          - Left: prompt and negative-prompt textboxes, frame-count and
            step-count sliders, and a "Generate Video" button.
          - Right: the output video.

    Clicking the button calls ``generate_video`` with
    ``(prompt, negative_prompt, video_length, steps)`` and shows the returned
    file in the video component.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    header = "\n# Latte Video Generation\nGenerate short videos using the Latte-1 model.\n"

    with gr.Blocks() as demo:
        gr.Markdown(header)
        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Prompt",
                    value="a cat wearing sunglasses and working as a lifeguard at pool.",
                    info="Describe what you want to generate",
                )
                negative_box = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    info="What you don't want to see in the generation",
                )
                frames_slider = gr.Slider(
                    label="Video Length (frames)",
                    minimum=8,
                    maximum=32,
                    step=8,
                    value=16,
                )
                steps_slider = gr.Slider(
                    label="Number of Inference Steps",
                    minimum=20,
                    maximum=100,
                    step=10,
                    value=50,
                )
                run_button = gr.Button("Generate Video")
            with gr.Column():
                video_out = gr.Video(label="Generated Video")

        # Input order must match generate_video's positional parameters.
        ui_inputs = [prompt_box, negative_box, frames_slider, steps_slider]
        run_button.click(
            fn=generate_video,
            inputs=ui_inputs,
            outputs=video_out,
        )
    return demo
if __name__ == "__main__":
    # Build the UI, enable the request queue, and serve it without a public share link.
    demo = create_demo()
    demo.queue().launch(share=False)