alex committed on
Commit 683f192 · 1 Parent(s): 12d3925

more cleanup

Files changed (1)
  1. app.py +34 -40
app.py CHANGED
@@ -4,6 +4,8 @@ import os
 import subprocess
 import importlib, site
 from PIL import Image
+import uuid
+import shutil
 
 # Re-discover all .pth/.egg-link files
 for sitedir in site.getsitepackages():
@@ -40,6 +42,8 @@ import torch
 print(f"Torch version: {torch.__version__}")
 print(f"FlashAttention available: {flash_attention_installed}")
 
+os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
+
 import gradio as gr
 import argparse
 from ovi.ovi_fusion_engine import OviFusionEngine, DEFAULT_CONFIG
@@ -52,11 +56,7 @@ from ovi.utils.processing_utils import clean_text, scale_hw_to_area_divisible
 # Parse CLI Args
 # ----------------------------
 parser = argparse.ArgumentParser(description="Ovi Joint Video + Audio Gradio Demo")
-parser.add_argument(
-    "--use_image_gen",
-    action="store_true",
-    help="Enable image generation UI with FluxPipeline"
-)
+
 parser.add_argument(
     "--cpu_offload",
     action="store_true",
@@ -99,16 +99,11 @@ snapshot_download(
 )
 
 # Initialize OviFusionEngine
-enable_cpu_offload = args.cpu_offload or args.use_image_gen
-use_image_gen = args.use_image_gen
-print(f"loading model... {enable_cpu_offload=}, {use_image_gen=} for gradio demo")
+enable_cpu_offload = args.cpu_offload
+print(f"loading model...")
 DEFAULT_CONFIG['cpu_offload'] = enable_cpu_offload # always use cpu offload if image generation is enabled
 DEFAULT_CONFIG['mode'] = "t2v" # hardcoded since it is always cpu offloaded
 ovi_engine = OviFusionEngine()
-flux_model = None
-if use_image_gen:
-    flux_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
-    flux_model.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU VRAM
 print("loaded model")
 
 
@@ -170,6 +165,7 @@ def generate_video(
     slg_layer = 11,
     video_negative_prompt = "",
     audio_negative_prompt = "",
+    session_id = None,
     progress=gr.Progress(track_tqdm=True)
 ):
     try:
@@ -178,6 +174,15 @@ def generate_video(
         if image is not None:
             image_path = image
 
+        if session_id is None:
+            session_id = uuid.uuid4().hex
+
+
+        output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
+        os.makedirs(output_dir, exist_ok=True)
+        output_path = os.path.join(output_dir, f"generated_video.mp4")
+
+
         _, target_size = resize_for_model(image_path)
 
         video_frame_width = target_size[0]
@@ -198,8 +203,6 @@ def generate_video(
             audio_negative_prompt=audio_negative_prompt,
         )
 
-        tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
-        output_path = tmpfile.name
         save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
 
         return output_path
@@ -208,24 +211,16 @@ def generate_video(
         return None
 
 
-def generate_image(text_prompt, image_seed, image_height, image_width):
-    if flux_model is None:
-        return None
-    text_prompt = clean_text(text_prompt)
-    print(f"Generating image with prompt='{text_prompt}', seed={image_seed}, size=({image_height},{image_width})")
-
-    image_h, image_w = scale_hw_to_area_divisible(image_height, image_width, area=1024 * 1024)
-    image = flux_model(
-        text_prompt,
-        height=image_h,
-        width=image_w,
-        guidance_scale=4.5,
-        generator=torch.Generator().manual_seed(int(image_seed))
-    ).images[0]
-
-    tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-    image.save(tmpfile.name)
-    return tmpfile.name
+def cleanup(request: gr.Request):
+
+    sid = request.session_hash
+    if sid:
+        d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
+        shutil.rmtree(d1, ignore_errors=True)
+
+def start_session(request: gr.Request):
+
+    return request.session_hash
 
 css = """
 #col-container {
@@ -236,6 +231,9 @@ css = """
 
 with gr.Blocks(css=css) as demo:
 
+    session_state = gr.State()
+    demo.load(start_session, outputs=[session_state])
+
     with gr.Column(elem_id="col-container"):
         gr.HTML(
             """
@@ -337,13 +335,6 @@ with gr.Blocks(css=css) as demo:
         cache_examples=True,
    )
 
-    if args.use_image_gen and gen_img_btn is not None:
-        gen_img_btn.click(
-            fn=generate_image,
-            inputs=[image_text_prompt, image_seed, image_height, image_width],
-            outputs=[image],
-        )
-
     run_btn.click(
         fn=generate_video,
         inputs=[video_text_prompt, image, sample_steps],
@@ -351,4 +342,7 @@ with gr.Blocks(css=css) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.unload(cleanup)
+    demo.queue()
+    demo.launch(ssr_mode=False, share=True)
+
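In short, this commit drops the optional FluxPipeline image-generation path (the --use_image_gen flag, generate_image, and its button wiring) and replaces the anonymous tempfile outputs with per-session directories under processed_results/, keyed by Gradio's session hash and deleted when the browser session ends. Below is a minimal, self-contained sketch of that session-lifecycle pattern, assuming a recent Gradio release where gr.Request.session_hash, Blocks.load, Blocks.unload, and Blocks.queue behave as used in the diff; RESULTS_ROOT, run, and the write-a-text-file task are illustrative stand-ins, not the app's real code.

# Sketch of the per-session output + cleanup pattern adopted by this commit.
# Assumes Gradio injects a gr.Request into handlers that declare one, and that
# Blocks.unload fires when a browser session disconnects (as relied on above).
import os
import shutil
import uuid

import gradio as gr

RESULTS_ROOT = os.path.join(os.getcwd(), "processed_results")


def start_session(request: gr.Request):
    # Seed the per-tab gr.State with Gradio's session hash on page load.
    return request.session_hash


def run(text, session_id=None):
    # Fall back to a random id so the function also works outside the UI.
    sid = session_id or uuid.uuid4().hex
    out_dir = os.path.join(RESULTS_ROOT, sid)
    os.makedirs(out_dir, exist_ok=True)
    # Stand-in for save_video(): write the result into this session's directory.
    out_path = os.path.join(out_dir, "result.txt")
    with open(out_path, "w") as f:
        f.write(text)
    return out_path


def cleanup(request: gr.Request):
    # Remove only this session's directory when its browser tab disconnects.
    if request.session_hash:
        shutil.rmtree(os.path.join(RESULTS_ROOT, request.session_hash),
                      ignore_errors=True)


with gr.Blocks() as demo:
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])
    prompt = gr.Textbox(label="Prompt")
    result = gr.File(label="Result")
    prompt.submit(run, inputs=[prompt, session_state], outputs=[result])

if __name__ == "__main__":
    demo.unload(cleanup)
    demo.queue()
    demo.launch()

Keeping each visitor's artifacts in one directory keyed by session_hash is what makes the unload cleanup safe: removing that single tree cannot touch files belonging to other concurrent sessions, which leaked under the old delete=False tempfile approach.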