multimodalart (HF Staff) committed · verified
Commit ecda25f · 1 Parent(s): c4497fb

Create app.py

Files changed (1): app.py (+282, -0)
app.py ADDED
@@ -0,0 +1,282 @@
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
import spaces
from PIL import Image

subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
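# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes flash-attn's setup skip compiling the
# CUDA extension so pip can fall back to a prebuilt wheel; merging os.environ keeps
# PATH (and the rest of the environment) intact so the shell can still find pip.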

# --- 1. Initial Setup: Auto-Download Git Repo and Model Weights ---

# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the repository if it doesn't exist
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
            check=True,
            capture_output=True,
        )
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr.decode()}")
        sys.exit(1)

# Add the cloned repository to the Python path to allow imports
sys.path.insert(0, os.path.abspath(REPO_PATH))

# Now that the repo is on the path, we can import its modules. AutoTokenizer and
# UMT5EncoderModel come from transformers and are needed for the text encoder below.
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, UMT5EncoderModel
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from diffusers.utils import export_to_video

# Download model weights from Hugging Face Hub if they don't exist
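# snapshot_download reuses files already present locally, and ignore_patterns
# filters out docs and assets that aren't needed for inference.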
if not os.path.exists(CHECKPOINT_DIR):
    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
    try:
        snapshot_download(
            repo_id="meituan-longcat/LongCat-Video",
            local_dir=CHECKPOINT_DIR,
            local_dir_use_symlinks=False,  # Use False for better Windows compatibility
            ignore_patterns=["*.md", "*.gitattributes", "assets/*"],  # ignore non-essential files
        )
        print("Model weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        sys.exit(1)

# --- 2. Global Variables & Model Loading (in Global Context) ---

# Global placeholder for the pipeline and device configuration
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models (loaded once at startup) ---")
try:
    # Context parallel is not used in this single-instance demo, but the model requires the config.
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    print("Loading tokenizer and text encoder...")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)

    print("Loading VAE and Scheduler...")
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

    print("Loading DiT model...")
    dit = LongCatVideoTransformer3DModel.from_pretrained(CHECKPOINT_DIR, subfolder="dit", cp_split_hw=cp_split_hw, torch_dtype=torch_dtype)

    print("Creating LongCatVideoPipeline...")
    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipe.to(device)

    print("Loading LoRA weights for optional modes...")
    cfg_step_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors')
    pipe.dit.load_lora(cfg_step_lora_path, 'cfg_step_lora')

    refinement_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors')
    pipe.dit.load_lora(refinement_lora_path, 'refinement_lora')
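    # cfg_step_lora backs the distilled few-step sampler (16 steps, no CFG);
    # refinement_lora drives the optional second-stage refinement pass in generate_video.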

    print("--- Models loaded successfully and are ready for inference. ---")

except Exception as e:
    print("--- FATAL ERROR: Failed to load models. ---")
    print(f"Details: {e}")
    # The app will still run, but generation will fail with an error message.
    pipe = None

# --- 3. Generation Logic ---

def torch_gc():
    """Helper function to clean up GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

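# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of the call;
# duration=500 widens the allocation window to 500 s for long two-stage generations.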
@spaces.GPU(duration=500)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Universal video generation function.
    """
    if pipe is None:
        raise gr.Error("Models failed to load. Please check the console output for errors and restart the app.")

    generator = torch.Generator(device=device).manual_seed(int(seed))

    # --- Stage 1: Base Generation (Standard or Distill) ---
    progress(0, desc="Starting Stage 1: Base Generation")

    num_frames = 93  # Default from demo scripts
    is_distill = use_distill or use_refine  # Refinement requires a distilled video as input

    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
        num_inference_steps = 16
        guidance_scale = 1.0
        current_neg_prompt = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        current_neg_prompt = neg_prompt
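    # With guidance_scale=1.0, classifier-free guidance is effectively disabled, so
    # the negative prompt would have no effect; that is why distill mode clears it.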

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    elif mode == "i2v":
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    if is_distill:
        pipe.dit.disable_all_loras()

    torch_gc()

    # --- Stage 2: Refinement (Optional) ---
    if use_refine:
        progress(0.5, desc="Starting Stage 2: Refinement")

        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()

        stage1_video_pil = [(frame * 255).astype(np.uint8) for frame in output]
        stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]
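        # Assumes the pipeline returns frames as float arrays in [0, 1], hence the
        # *255 rescale before converting to uint8 PIL images for the refiner.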

        refine_image = Image.fromarray(image) if mode == 'i2v' else None

        output = pipe.generate_refine(
            image=refine_image,
            prompt=prompt,
            stage1_video=stage1_video_pil,
            num_cond_frames=1 if mode == 'i2v' else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    # --- Post-processing and Output ---
    progress(1.0, desc="Exporting video")
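    # Refined output is written at 30 fps and the base pass at 15 fps; presumably
    # the refinement stage increases the effective frame rate of the video.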

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
        fps = 30 if use_refine else 15
        export_to_video(output, temp_video_file.name, fps=fps)
        return temp_video_file.name

# --- 4. Gradio UI Definition ---

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# 🎬 LongCat-Video Demo")
    gr.Markdown(
        "A one-click Gradio interface for LongCat-Video. "
        "The first time you run this, it will automatically clone the official repository and download the model weights."
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("Text-to-Video", id=0):
            mode_t2v = gr.State("t2v")
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_t2v = gr.Textbox(label="Prompt", lines=4, placeholder="A cinematic shot of a Corgi walking on the beach.")
                    neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
                    with gr.Row():
                        height_t2v = gr.Slider(label="Height", minimum=256, maximum=1024, value=480, step=64)
                        width_t2v = gr.Slider(label="Width", minimum=256, maximum=1024, value=832, step=64)
                    with gr.Row():
                        seed_t2v = gr.Number(label="Seed", value=42, precision=0)
                        distill_t2v = gr.Checkbox(label="Use Distill Mode", value=False, info="Faster, lower quality base generation.")
                        refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")

                    t2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_t2v = gr.Video(label="Generated Video", interactive=False)

        with gr.TabItem("Image-to-Video", id=1):
            mode_i2v = gr.State("i2v")
            with gr.Row():
                with gr.Column(scale=2):
                    image_i2v = gr.Image(type="numpy", label="Input Image")
                    prompt_i2v = gr.Textbox(label="Prompt", lines=4, placeholder="The cat in the image wags its tail and blinks.")
                    neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
                    resolution_i2v = gr.Dropdown(label="Resolution", choices=["480p", "720p"], value="480p")
                    with gr.Row():
                        seed_i2v = gr.Number(label="Seed", value=42, precision=0)
                        distill_i2v = gr.Checkbox(label="Use Distill Mode", value=False, info="Faster, lower quality base generation.")
                        refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")

                    i2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_i2v = gr.Video(label="Generated Video", interactive=False)

    # --- Event Handlers ---
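    # The gr.State(None) placeholders pad each input list so both tabs can share
    # generate_video's single signature (mode, prompt, neg_prompt, image, height,
    # width, resolution, seed, use_distill, use_refine).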
    t2v_inputs = [
        mode_t2v, prompt_t2v, neg_prompt_t2v,
        gr.State(None),  # Placeholder for image
        height_t2v, width_t2v,
        gr.State(None),  # Placeholder for resolution
        seed_t2v, distill_t2v, refine_t2v
    ]
    t2v_button.click(fn=generate_video, inputs=t2v_inputs, outputs=video_output_t2v)

    i2v_inputs = [
        mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
        gr.State(None), gr.State(None),  # Placeholders for height/width
        resolution_i2v,
        seed_i2v, distill_i2v, refine_i2v
    ]
    i2v_button.click(fn=generate_video, inputs=i2v_inputs, outputs=video_output_i2v)

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch()