ford442 commited on
Commit
e609c83
·
verified ·
1 Parent(s): b112fda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -17,6 +17,8 @@ import torch
17
  import cv2
18
  import gc
19
 
 
 
20
  torch.backends.cuda.matmul.allow_tf32 = False
21
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
22
  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -40,6 +42,8 @@ import shutil
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
 
 
 
43
  #import diffusers
44
  from diffusers import StableDiffusionXLImg2ImgPipeline, AutoencoderKL
45
  print("Loading SDXL Image-to-Image pipeline...")
@@ -117,16 +121,20 @@ def get_duration(*args, **kwargs):
117
  return 90
118
 
119
 
120
- @spaces.GPU()
121
- def enhance_frame(image_to_enhance: Image.Image):
122
  try:
123
  print("Moving enhancer pipeline to GPU...")
124
  seed = random.randint(0, MAX_SEED)
125
  generator = torch.Generator(device='cpu').manual_seed(seed)
126
  enhancer_pipeline.to("cuda",torch.bfloat16)
127
- refine_prompt = "cinematic, high detail, sharp focus, 8k, professional photography"
128
- enhanced_image = enhancer_pipeline(prompt=refine_prompt, image=image_to_enhance, strength=0.125, generator=generator, num_inference_steps=100).images[0]
129
  print("Frame enhancement successful.")
 
 
 
 
130
  except Exception as e:
131
  print(f"Error during frame enhancement: {e}")
132
  gr.Warning("Frame enhancement failed. Using original frame.")
@@ -139,7 +147,7 @@ def enhance_frame(image_to_enhance: Image.Image):
139
  return enhanced_image
140
 
141
 
142
- def use_last_frame_as_input(video_filepath, do_enhance):
143
  if not video_filepath or not os.path.exists(video_filepath):
144
  gr.Warning("No video clip available.")
145
  return None, gr.update()
@@ -155,7 +163,7 @@ def use_last_frame_as_input(video_filepath, do_enhance):
155
  print("Displaying original last frame...")
156
  yield pil_image, gr.update()
157
  if do_enhance:
158
- enhanced_image = enhance_frame(pil_image)
159
  # 2. Yield the enhanced frame and switch the tab
160
  print("Displaying enhanced frame and switching tab...")
161
  yield enhanced_image, gr.update(selected="i2v_tab")
@@ -312,7 +320,7 @@ with gr.Blocks(css=css) as demo:
312
  t2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=t2v_inputs, outputs=gen_outputs, api_name="text_to_video")
313
  i2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=i2v_inputs, outputs=gen_outputs, api_name="image_to_video")
314
  v2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=v2v_inputs, outputs=gen_outputs, api_name="video_to_video")
315
- use_last_frame_button.click(fn=use_last_frame_as_input, inputs=[output_video,enhance_checkbox], outputs=[image_i2v, tabs])
316
  stitch_button.click(fn=stitch_videos, inputs=[clips_state], outputs=[final_video_output])
317
  clear_button.click(fn=clear_clips, outputs=[clips_state, clip_counter_display, output_video, final_video_output])
318
  if __name__ == "__main__":
 
17
  import cv2
18
  import gc
19
 
20
+ from image_gen_aux import UpscaleWithModel
21
+
22
  torch.backends.cuda.matmul.allow_tf32 = False
23
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
24
  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
 
42
 
43
  MAX_SEED = np.iinfo(np.int32).max
44
 
45
+ upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))
46
+
47
  #import diffusers
48
  from diffusers import StableDiffusionXLImg2ImgPipeline, AutoencoderKL
49
  print("Loading SDXL Image-to-Image pipeline...")
 
121
  return 90
122
 
123
 
124
+ @spaces.GPU(duration=45)
125
+ def enhance_frame(prompt, image_to_enhance: Image.Image):
126
  try:
127
  print("Moving enhancer pipeline to GPU...")
128
  seed = random.randint(0, MAX_SEED)
129
  generator = torch.Generator(device='cpu').manual_seed(seed)
130
  enhancer_pipeline.to("cuda",torch.bfloat16)
131
+ refine_prompt = prompt +" high detail, sharp focus, 8k, professional"
132
+ enhanced_image = enhancer_pipeline(prompt=refine_prompt, image=image_to_enhance, strength=0.1, generator=generator, num_inference_steps=220).images[0]
133
  print("Frame enhancement successful.")
134
+ print("Doing super-resolution.")
135
+ with torch.no_grad():
136
+ upscale = upscaler(enhanced_image, tiling=True, tile_width=1024, tile_height=1024)
137
+ enhanced_image = upscale.resize((upscale.width // 4, upscale.height // 4), Image.LANCZOS)
138
  except Exception as e:
139
  print(f"Error during frame enhancement: {e}")
140
  gr.Warning("Frame enhancement failed. Using original frame.")
 
147
  return enhanced_image
148
 
149
 
150
+ def use_last_frame_as_input(prompt, video_filepath, do_enhance):
151
  if not video_filepath or not os.path.exists(video_filepath):
152
  gr.Warning("No video clip available.")
153
  return None, gr.update()
 
163
  print("Displaying original last frame...")
164
  yield pil_image, gr.update()
165
  if do_enhance:
166
+ enhanced_image = enhance_frame(prompt, pil_image)
167
  # 2. Yield the enhanced frame and switch the tab
168
  print("Displaying enhanced frame and switching tab...")
169
  yield enhanced_image, gr.update(selected="i2v_tab")
 
320
  t2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=t2v_inputs, outputs=gen_outputs, api_name="text_to_video")
321
  i2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=i2v_inputs, outputs=gen_outputs, api_name="image_to_video")
322
  v2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=v2v_inputs, outputs=gen_outputs, api_name="video_to_video")
323
+ use_last_frame_button.click(fn=use_last_frame_as_input, inputs=[i2v_prompt,output_video,enhance_checkbox], outputs=[image_i2v, tabs])
324
  stitch_button.click(fn=stitch_videos, inputs=[clips_state], outputs=[final_video_output])
325
  clear_button.click(fn=clear_clips, outputs=[clips_state, clip_counter_display, output_video, final_video_output])
326
  if __name__ == "__main__":