Ji4chenLi committed
Commit 7a34add
1 Parent(s): 70cdee7

Support higher versions of gradio

Files changed (3):
  1. app.py +132 -221
  2. requirements.txt +0 -1
  3. style.css +0 -16
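The heart of the change is visible in app.py below: the gr.Blocks layout wired together with gr.on triggers, OAuth profile arguments, and cached examples is replaced by a single gr.Interface, which does the event wiring itself and ports cleanly to newer Gradio releases. As a minimal sketch of that pattern (illustrative only; the function and labels are placeholders, not code from this commit, and gradio 4.x is assumed):

import gradio as gr

def echo(prompt: str) -> str:
    # placeholder standing in for the Space's generate() function
    return prompt

demo = gr.Interface(
    fn=echo,
    inputs=[gr.Textbox(label="Prompt")],
    outputs=[gr.Textbox(label="Echo")],
    concurrency_limit=10,  # per-event limit, as used in the commit below
)
demo.launch()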
app.py CHANGED
@@ -1,18 +1,16 @@
 import os
 import uuid
-import gradio as gr
-import numpy as np
-import random
-import time
 from omegaconf import OmegaConf
-
 import spaces
 
+import random
+
+import imageio
 import torch
 import torchvision
-
-from concurrent.futures import ThreadPoolExecutor
-import uuid
+import gradio as gr
+import numpy as np
+from gradio.components import Textbox, Video
 
 from utils.lora import collapse_lora, monkeypatch_remove_lora
 from utils.lora_handler import LoraHandler
@@ -21,11 +19,6 @@ from utils.utils import instantiate_from_config
 from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
 from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline
 
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-MAX_SEED = np.iinfo(np.int32).max
-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
-
 DESCRIPTION = """# T2V-Turbo 🚀
 We provide T2V-Turbo (VC2) distilled from [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2/) with the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and [InternVid2 Stage 2 Model](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4).
 
@@ -37,9 +30,84 @@ elif hasattr(torch, "xpu") and torch.xpu.is_available():
     DESCRIPTION += "\n<p>Running on XPU 🤓</p>"
 else:
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+MAX_SEED = np.iinfo(np.int32).max
 
 
-if torch.cuda.is_available():
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+
+def save_video(video_array, video_save_path, fps: int = 16):
+    video = video_array.detach().cpu()
+    video = torch.clamp(video.float(), -1.0, 1.0)
+    video = video.permute(1, 0, 2, 3)  # t,c,h,w
+    video = (video + 1.0) / 2.0
+    video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)
+
+    torchvision.io.write_video(
+        video_save_path, video, fps=fps, video_codec="h264", options={"crf": "10"}
+    )
+
+example_txt = [
+    "An astronaut riding a horse.",
+    "Darth vader surfing in waves.",
+    "Robot dancing in times square.",
+    "Clown fish swimming through the coral reef.",
+    "Pikachu snowboarding.",
+    "With the style of van gogh, A young couple dances under the moonlight by the lake.",
+    "A young woman with glasses is jogging in the park wearing a pink headband.",
+    "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
+    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+    "With the style of low-poly game art, A majestic, white horse gallops gracefully across a moonlit beach.",
+]
+
+examples = [[i, 7.5, 4, 16, 16] for i in example_txt]
+
+@spaces.GPU(duration=300)
+@torch.inference_mode()
+def generate(
+    prompt: str,
+    guidance_scale: float = 7.5,
+    num_inference_steps: int = 4,
+    num_frames: int = 16,
+    fps: int = 16,
+    seed: int = 0,
+    randomize_seed: bool = False,
+):
+
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    result = pipeline(
+        prompt=prompt,
+        frames=num_frames,
+        fps=fps,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        num_videos_per_prompt=1,
+    )
+
+    torch.cuda.empty_cache()
+    tmp_save_path = "tmp.mp4"
+    root_path = "./videos/"
+    os.makedirs(root_path, exist_ok=True)
+    video_save_path = os.path.join(root_path, tmp_save_path)
+
+    save_video(result[0], video_save_path, fps=fps)
+    display_model_info = f"Video size: {num_frames}x320x512, Sampling Step: {num_inference_steps}, Guidance Scale: {guidance_scale}"
+    return video_save_path, prompt, display_model_info, seed
+
+
+block_css = """
+#buttons button {
+    min-width: min(120px,100%);
+}
+"""
+
+
+if __name__ == "__main__":
+    device = torch.device("cuda:0")
+
     config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
     model_config = config.pop("model", OmegaConf.create())
     pretrained_t2v = instantiate_from_config(model_config)
@@ -72,8 +140,6 @@ if torch.cuda.is_available():
     collapse_lora(unet, lora_manager.unet_replace_modules)
     monkeypatch_remove_lora(unet)
 
-    torch.save(unet.state_dict(), "checkpoints/merged_unet.pt")
-
    pretrained_t2v.model.diffusion_model = unet
     scheduler = T2VTurboScheduler(
         linear_start=model_config["params"]["linear_start"],
@@ -82,215 +148,60 @@ if torch.cuda.is_available():
     pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
 
     pipeline.to(device)
-else:
-    assert False
-
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
-
-def save_video(
-    vid_tensor, profile: gr.OAuthProfile | None, metadata: dict, root_path="./", fps=16
-):
-    unique_name = str(uuid.uuid4()) + ".mp4"
-    prefix = ""
-    for k, v in metadata.items():
-        prefix += f"{k}={v}_"
-    unique_name = prefix + unique_name
-    unique_name = os.path.join(root_path, unique_name)
-
-    video = vid_tensor.detach().cpu()
-    video = torch.clamp(video.float(), -1.0, 1.0)
-    video = video.permute(1, 0, 2, 3)  # t,c,h,w
-    video = (video + 1.0) / 2.0
-    video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)
-
-    torchvision.io.write_video(
-        unique_name, video, fps=fps, video_codec="h264", options={"crf": "10"}
-    )
-    return unique_name
-
-
-def save_videos(
-    video_array, profile: gr.OAuthProfile | None, metadata: dict, fps: int = 16
-):
-    paths = []
-    root_path = "./videos/"
-    os.makedirs(root_path, exist_ok=True)
-    with ThreadPoolExecutor() as executor:
-        paths = list(
-            executor.map(
-                save_video,
-                video_array,
-                [profile] * len(video_array),
-                [metadata] * len(video_array),
-                [root_path] * len(video_array),
-                [fps] * len(video_array),
-            )
-        )
-    return paths[0]
-
-
-@spaces.GPU(duration=60)
-def generate(
-    prompt: str,
-    seed: int = 0,
-    guidance_scale: float = 7.5,
-    num_inference_steps: int = 4,
-    num_frames: int = 16,
-    fps: int = 16,
-    randomize_seed: bool = False,
-    param_dtype="torch.float16",
-    progress=gr.Progress(track_tqdm=True),
-    profile: gr.OAuthProfile | None = None,
-):
-    seed = randomize_seed_fn(seed, randomize_seed)
-    torch.manual_seed(seed)
-    pipeline.to(
-        torch_device=device,
-        torch_dtype=torch.float16 if param_dtype == "torch.float16" else torch.float32,
-    )
-    start_time = time.time()
 
-    result = pipeline(
-        prompt=prompt,
-        frames=num_frames,
-        fps=fps,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        num_videos_per_prompt=1,
-    )
-    paths = save_videos(
-        result,
-        profile,
-        metadata={
-            "prompt": prompt,
-            "seed": seed,
-            "guidance_scale": guidance_scale,
-            "num_inference_steps": num_inference_steps,
-        },
-        fps=fps,
-    )
-    print(time.time() - start_time)
-    return paths, seed
-
-examples = [
-    "An astronaut riding a horse.",
-    "Darth vader surfing in waves.",
-    "Robot dancing in times square.",
-    "Clown fish swimming through the coral reef.",
-    "Pikachu snowboarding.",
-    "With the style of van gogh, A young couple dances under the moonlight by the lake.",
-    "A young woman with glasses is jogging in the park wearing a pink headband.",
-    "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
-    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
-    "With the style of low-poly game art, A majestic, white horse gallops gracefully across a moonlit beach.",
-]
-
-
-if torch.cuda.is_available():
-    power_device = "GPU"
-else:
-    power_device = "CPU"
-
-with gr.Blocks(css="style.css") as demo:
-
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(DESCRIPTION)
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-            run_button = gr.Button("Run", scale=0)
-        result_video = gr.Video(
-            label="Generated Video", interactive=False, autoplay=True
-        )
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            seed = gr.Slider(
+    demo = gr.Interface(
+        fn=generate,
+        inputs=[
+            Textbox(label="", placeholder="Please enter your prompt. \n"),
+            gr.Slider(
+                label="Guidance scale",
+                minimum=2,
+                maximum=14,
+                step=0.1,
+                value=7.5,
+            ),
+            gr.Slider(
+                label="Number of inference steps",
+                minimum=1,
+                maximum=8,
+                step=1,
+                value=4,
+            ),
+            gr.Slider(
+                label="Number of Video Frames",
+                minimum=16,
+                maximum=48,
+                step=8,
+                value=16,
+            ),
+            gr.Slider(
+                label="FPS",
+                minimum=8,
+                maximum=32,
+                step=4,
+                value=16,
+            ),
+            gr.Slider(
                 label="Seed",
                 minimum=0,
                 maximum=MAX_SEED,
                 step=1,
                 value=0,
                 randomize=True,
-            )
-            randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
-            dtype_choices = ["torch.float16", "torch.float32"]
-            param_dtype = gr.Radio(
-                dtype_choices,
-                label="torch.dtype",
-                value=dtype_choices[0],
-                interactive=True,
-                info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
-            )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale for base",
-                    minimum=2,
-                    maximum=14,
-                    step=0.1,
-                    value=7.5,
-                )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps for base",
-                    minimum=1,
-                    maximum=8,
-                    step=1,
-                    value=4,
-                )
-            with gr.Row():
-                num_frames = gr.Slider(
-                    label="Number of Video Frames",
-                    minimum=16,
-                    maximum=48,
-                    step=8,
-                    value=16,
-                )
-                fps = gr.Slider(
-                    label="FPS",
-                    minimum=8,
-                    maximum=32,
-                    step=4,
-                    value=16,
-                )
-
-    gr.Examples(
-        examples=examples,
-        inputs=prompt,
-        outputs=result_video,
-        fn=generate,
-        cache_examples=CACHE_EXAMPLES,
-    )
-
-    gr.on(
-        triggers=[
-            prompt.submit,
-            run_button.click,
-        ],
-        fn=generate,
-        inputs=[
-            prompt,
-            seed,
-            guidance_scale,
-            num_inference_steps,
-            num_frames,
-            fps,
-            randomize_seed,
-            param_dtype,
-        ],
-        outputs=[result_video, seed],
-        api_name="run",
-    )
-
-demo.queue().launch()
+            ),
+            gr.Checkbox(label="Randomize seed", value=True),
+        ],
+        outputs=[
+            gr.Video(label="Generated Video", width=512, height=320, interactive=False, autoplay=True),
+            Textbox(label="input prompt"),
+            Textbox(label="model info"),
+            gr.Slider(label="seed"),
+        ],
+        description=DESCRIPTION,
+        theme=gr.themes.Default(),
+        css=block_css,
+        examples=examples,
+        cache_examples=False,
+        concurrency_limit=10,
+    )
+    demo.launch()
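One detail worth noting from the new save_video helper above: the pipeline output arrives as a (channels, frames, height, width) tensor in [-1, 1], while torchvision.io.write_video expects (frames, height, width, channels) uint8. The conversion can be sanity-checked in isolation; a small sketch with a dummy tensor (the 320x512 resolution is assumed from the display_model_info string in the diff):

import torch

# dummy clip standing in for result[0]: (c, t, h, w) in [-1, 1]
video = torch.rand(3, 16, 320, 512) * 2 - 1

video = torch.clamp(video.float(), -1.0, 1.0)
video = video.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)
video = (video + 1.0) / 2.0        # [-1, 1] -> [0, 1]
video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)  # -> (t, h, w, c)

assert video.shape == (16, 320, 512, 3) and video.dtype == torch.uint8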
requirements.txt CHANGED
@@ -5,7 +5,6 @@ transformers==4.40.0
 accelerate==0.29.3
 imageio==2.34.0
 decord==0.6.0
-gradio==3.48.0
 opencv-python
 spaces
 einops
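With the gradio==3.48.0 pin removed, the Space runs whatever Gradio version the Spaces SDK provides, and the rewritten app.py presumably targets the 4.x API (concurrency_limit on gr.Interface, for instance). A defensive version check of this kind, not part of the commit, can fail fast if an older install sneaks in:

import gradio as gr
from packaging.version import Version

# assumption: the rewritten app needs gradio >= 4
if Version(gr.__version__) < Version("4.0"):
    raise RuntimeError(f"gradio {gr.__version__} is too old for this app")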
style.css DELETED
@@ -1,16 +0,0 @@
-h1 {
-    text-align: center;
-}
-
-#duplicate-button {
-    margin: auto;
-    color: #fff;
-    background: #1565c0;
-    border-radius: 100vh;
-}
-
-#component-0 {
-    max-width: 830px;
-    margin: auto;
-    padding-top: 1.5rem;
-}