| """ | |
| This script is used to create a Streamlit web application for generating videos using the CogVideoX model. | |
| Run the script using Streamlit: | |
| $ export OPENAI_API_KEY=your OpenAI Key or ZhiupAI Key | |
| $ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/ # using with ZhipuAI, Not using this when using OpenAI | |
| $ streamlit run web_demo.py | |
| """ | |
import base64
import json
import os
import time
from datetime import datetime
from typing import List

import imageio
import numpy as np
import streamlit as st
import torch

from convert_demo import convert_prompt
from diffusers import CogVideoXPipeline

model_path: str = "THUDM/CogVideoX-2b"


# Load the model at the start
def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
    """
    Load the CogVideoX model.

    Args:
    - model_path (str): Path to the model.
    - dtype (torch.dtype): Data type for the model.
    - device (str): Device to load the model on.

    Returns:
    - CogVideoXPipeline: Loaded model pipeline.
    """
    return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
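

# Hedged usage sketch (assumes a diffusers release with CogVideoX support and
# `accelerate` installed): on GPUs with limited VRAM, CPU offloading is an
# alternative to moving the whole pipeline onto the device with .to(device):
#
#   pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
#   pipe.enable_model_cpu_offload()  # diffusers offloading helper, replaces .to("cuda")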


# Define a function to generate video based on the provided prompt and model path
def generate_video(
    pipe: CogVideoXPipeline,
    prompt: str,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
) -> List[np.ndarray]:
    """
    Generate a video based on the provided prompt and model path.

    Args:
    - pipe (CogVideoXPipeline): The pipeline for generating videos.
    - prompt (str): Text prompt for video generation.
    - num_inference_steps (int): Number of inference steps.
    - guidance_scale (float): Guidance scale for generation.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - device (str): Device to run the generation on.
    - dtype (torch.dtype): Data type for the model.

    Returns:
    - List[np.ndarray]: Generated video frames.
    """
    prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt,
        negative_prompt=None,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=226,
        device=device,
        dtype=dtype,
    )

    # Generate video
    video = pipe(
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=torch.zeros_like(prompt_embeds),
    ).frames[0]
    return video
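

# Minimal usage sketch (assumes the pipeline loaded above and a CUDA device;
# the prompt text is made up for illustration):
#
#   frames = generate_video(pipe, "A panda strumming a guitar in a bamboo grove")
#   save_video(frames, "./output/example.mp4", fps=8)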


def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
    """
    Save the generated video to a file.

    Args:
    - video (List[np.ndarray]): Video frames.
    - path (str): Path to save the video.
    - fps (int): Frames per second for the video.
    """
    # Remove the first frame
    video = video[1:]
    writer = imageio.get_writer(path, fps=fps, codec="libx264")
    for frame in video:
        np_frame = np.array(frame)
        writer.append_data(np_frame)
    writer.close()
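

# Hedged alternative (assumption: the installed diffusers version ships this
# helper): diffusers' own export utility covers the same MP4-writing step,
# minus the first-frame trim done above:
#
#   from diffusers.utils import export_to_video
#   export_to_video(video, path, fps=8)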


def save_metadata(
    prompt: str,
    converted_prompt: str,
    num_inference_steps: int,
    guidance_scale: float,
    num_videos_per_prompt: int,
    path: str,
) -> None:
    """
    Save metadata to a JSON file.

    Args:
    - prompt (str): Original prompt.
    - converted_prompt (str): Converted prompt.
    - num_inference_steps (int): Number of inference steps.
    - guidance_scale (float): Guidance scale.
    - num_videos_per_prompt (int): Number of videos per prompt.
    - path (str): Path to save the metadata.
    """
    metadata = {
        "prompt": prompt,
        "converted_prompt": converted_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "num_videos_per_prompt": num_videos_per_prompt,
    }
    with open(path, "w") as f:
        json.dump(metadata, f, indent=4)
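

# The resulting config.json looks like this (values are illustrative):
#
#   {
#       "prompt": "a panda playing guitar",
#       "converted_prompt": "A giant panda strums an acoustic guitar ...",
#       "num_inference_steps": 50,
#       "guidance_scale": 6.0,
#       "num_videos_per_prompt": 1
#   }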


def main() -> None:
    """
    Main function to run the Streamlit web application.
    """
    st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
    st.write("# CogVideoX 🎥")
    dtype: torch.dtype = torch.float16
    device: str = "cuda"

    global pipe
    pipe = load_model(model_path, dtype, device)

    with st.sidebar:
| st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️") | |
        num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
        guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
        num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
        share_links_container = st.empty()

    prompt: str = st.chat_input("Prompt")

    if prompt:
        # Optional step: refining the prompt is a suggestion, not a requirement
        with st.spinner("Refining prompts..."):
            converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
            if converted_prompt is None:
                st.error("Failed to refine the prompt; using the original one.")

        st.info(f"**Original prompt:** \n{prompt} \n \n**Converted prompt:** \n{converted_prompt}")
        torch.cuda.empty_cache()

        with st.spinner("Generating Video..."):
            start_time = time.time()
            video_paths = []

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = f"./output/{timestamp}"
            os.makedirs(output_dir, exist_ok=True)

            metadata_path = os.path.join(output_dir, "config.json")
            save_metadata(
                prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
            )

            for i in range(num_videos_per_prompt):
                video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
                video = generate_video(
                    pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
                )
                save_video(video, video_path, fps=8)
                video_paths.append(video_path)
                with open(video_path, "rb") as video_file:
                    video_bytes: bytes = video_file.read()
                    st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
                torch.cuda.empty_cache()

            used_time: float = time.time() - start_time
            st.success(f"Videos generated in {used_time:.2f} seconds.")
        # Create download links in the sidebar
        with share_links_container:
            st.sidebar.write("### Download Links:")
            for video_path in video_paths:
                video_name = os.path.basename(video_path)
                with open(video_path, "rb") as f:
                    video_bytes: bytes = f.read()
                    b64_video = base64.b64encode(video_bytes).decode()
                    href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
                    st.sidebar.markdown(href, unsafe_allow_html=True)


if __name__ == "__main__":
    main()