multimodalart HF staff commited on
Commit
8177071
1 Parent(s): 07d8f89

PR with param fix, steps and DPM Solver

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -6,7 +6,7 @@ import time
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
- from diffusers import CogVideoXPipeline
10
  from datetime import datetime, timedelta
11
  from openai import OpenAI
12
  import spaces
@@ -18,6 +18,7 @@ import PIL
18
  dtype = torch.float16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
 
21
 
22
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
23
 
@@ -88,7 +89,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
88
  return prompt
89
 
90
 
91
- @spaces.GPU(duration=200)
92
  def infer(
93
  prompt: str,
94
  num_inference_steps: int,
@@ -171,10 +172,9 @@ with gr.Blocks() as demo:
171
 
172
  with gr.Column():
173
  gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
174
- "Reduce the inference steps (such as 25) for faster generation, but this may degrade the quality of video.<br>"
175
- "50 steps are recommended for most cases. will cause 150 seconds for inference.<br>")
176
  with gr.Row():
177
- num_inference_steps = gr.Number(label="Inference Steps", value=50)
178
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
179
  generate_button = gr.Button("🎬 Generate Video")
180
 
 
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
+ from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler
10
  from datetime import datetime, timedelta
11
  from openai import OpenAI
12
  import spaces
 
18
  dtype = torch.float16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
21
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
22
 
23
  sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
24
 
 
89
  return prompt
90
 
91
 
92
+ @spaces.GPU(duration=120)
93
  def infer(
94
  prompt: str,
95
  num_inference_steps: int,
 
172
 
173
  with gr.Column():
174
  gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
175
+ "24 steps are recommended for most cases for a trade-off between speed and quality<br>")
 
176
  with gr.Row():
177
+ num_inference_steps = gr.Slider(label="Inference Steps", value=24, minimum=1, maximum=24)
178
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
179
  generate_button = gr.Button("🎬 Generate Video")
180