Update app.py

app.py CHANGED
@@ -17,6 +17,8 @@ import torch
 import cv2
 import gc
 
+from image_gen_aux import UpscaleWithModel
+
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -40,6 +42,8 @@ import shutil
 
 MAX_SEED = np.iinfo(np.int32).max
 
+upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))
+
 #import diffusers
 from diffusers import StableDiffusionXLImg2ImgPipeline, AutoencoderKL
 print("Loading SDXL Image-to-Image pipeline...")
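
Note on the new upscaler global: it loads image_gen_aux's UpscaleWithModel straight onto cuda:0 at import time and is later called with tiled inference. A minimal standalone sketch of the same helper, reusing only the model name and tile sizes that appear in this diff (the file names are hypothetical):

    from PIL import Image
    from image_gen_aux import UpscaleWithModel

    upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to("cuda")

    frame = Image.open("last_frame.png").convert("RGB")  # hypothetical input file
    # Tiling bounds peak VRAM on large frames; tile sizes mirror the diff.
    upscaled = upscaler(frame, tiling=True, tile_width=1024, tile_height=1024)
    upscaled.save("last_frame_up.png")  # ClearRealityV1 upscales 4x; the diff divides by 4 afterwards
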
@@ -117,16 +121,20 @@ def get_duration(*args, **kwargs):
     return 90
 
 
-@spaces.GPU()
-def enhance_frame(image_to_enhance: Image.Image):
+@spaces.GPU(duration=45)
+def enhance_frame(prompt, image_to_enhance: Image.Image):
     try:
         print("Moving enhancer pipeline to GPU...")
         seed = random.randint(0, MAX_SEED)
         generator = torch.Generator(device='cpu').manual_seed(seed)
         enhancer_pipeline.to("cuda",torch.bfloat16)
-        refine_prompt = "
-        enhanced_image = enhancer_pipeline(prompt=refine_prompt, image=image_to_enhance, strength=0.
+        refine_prompt = prompt +" high detail, sharp focus, 8k, professional"
+        enhanced_image = enhancer_pipeline(prompt=refine_prompt, image=image_to_enhance, strength=0.1, generator=generator, num_inference_steps=220).images[0]
         print("Frame enhancement successful.")
+        print("Doing super-resolution.")
+        with torch.no_grad():
+            upscale = upscaler(enhanced_image, tiling=True, tile_width=1024, tile_height=1024)
+        enhanced_image = upscale.resize((upscale.width // 4, upscale.height // 4), Image.LANCZOS)
     except Exception as e:
         print(f"Error during frame enhancement: {e}")
         gr.Warning("Frame enhancement failed. Using original frame.")
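
Note on the rewritten enhance_frame: the new path is a low-strength SDXL img2img refine followed by a 4x upscale and a LANCZOS resize back to the original size, i.e. a supersampling pass. With img2img, diffusers only runs about strength * num_inference_steps denoising steps, so strength=0.1 with 220 scheduled steps comes out to roughly 22 actual steps. A condensed sketch of the flow, assuming enhancer_pipeline, upscaler, and MAX_SEED are defined as above:

    import random
    import torch
    from PIL import Image

    def refine_and_supersample(prompt: str, frame: Image.Image) -> Image.Image:
        generator = torch.Generator(device="cpu").manual_seed(random.randint(0, MAX_SEED))
        refined = enhancer_pipeline(
            prompt=prompt + " high detail, sharp focus, 8k, professional",
            image=frame,
            strength=0.1,               # ~22 of the 220 scheduled steps actually run
            num_inference_steps=220,
            generator=generator,
        ).images[0]
        with torch.no_grad():
            up = upscaler(refined, tiling=True, tile_width=1024, tile_height=1024)
        # 4x up, then LANCZOS back down to the input size: supersampled detail.
        return up.resize((up.width // 4, up.height // 4), Image.LANCZOS)
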
@@ -139,7 +147,7 @@ def enhance_frame(image_to_enhance: Image.Image):
     return enhanced_image
 
 
-def use_last_frame_as_input(video_filepath, do_enhance):
+def use_last_frame_as_input(prompt, video_filepath, do_enhance):
     if not video_filepath or not os.path.exists(video_filepath):
         gr.Warning("No video clip available.")
         return None, gr.update()
@@ -155,7 +163,7 @@ def use_last_frame_as_input(video_filepath, do_enhance):
     print("Displaying original last frame...")
     yield pil_image, gr.update()
     if do_enhance:
-        enhanced_image = enhance_frame(pil_image)
+        enhanced_image = enhance_frame(prompt, pil_image)
         # 2. Yield the enhanced frame and switch the tab
         print("Displaying enhanced frame and switching tab...")
         yield enhanced_image, gr.update(selected="i2v_tab")
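
Note on the tab switch: use_last_frame_as_input is a generator, and Gradio streams each yield to the outputs, so the first yield shows the raw last frame immediately and the second swaps in the enhanced frame while selecting the i2v tab via gr.update(selected=...). A self-contained sketch of that two-stage pattern (component names are illustrative, not from app.py):

    import time
    import gradio as gr

    def slow_then_final(text):
        yield f"working on {text}...", gr.update()           # 1. intermediate value, no tab change
        time.sleep(1)
        yield f"done: {text}", gr.update(selected="second")  # 2. final value + switch tab

    with gr.Blocks() as demo:
        with gr.Tabs() as tabs:
            with gr.Tab("First", id="first"):
                inp = gr.Textbox()
                btn = gr.Button("Run")
            with gr.Tab("Second", id="second"):
                out = gr.Textbox()
        btn.click(slow_then_final, inputs=[inp], outputs=[out, tabs])

    if __name__ == "__main__":
        demo.launch()
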
@@ -312,7 +320,7 @@ with gr.Blocks(css=css) as demo:
     t2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=t2v_inputs, outputs=gen_outputs, api_name="text_to_video")
     i2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=i2v_inputs, outputs=gen_outputs, api_name="image_to_video")
     v2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=v2v_inputs, outputs=gen_outputs, api_name="video_to_video")
-    use_last_frame_button.click(fn=use_last_frame_as_input, inputs=[output_video,enhance_checkbox], outputs=[image_i2v, tabs])
+    use_last_frame_button.click(fn=use_last_frame_as_input, inputs=[i2v_prompt,output_video,enhance_checkbox], outputs=[image_i2v, tabs])
     stitch_button.click(fn=stitch_videos, inputs=[clips_state], outputs=[final_video_output])
     clear_button.click(fn=clear_clips, outputs=[clips_state, clip_counter_display, output_video, final_video_output])
 if __name__ == "__main__":
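
Note on the rewired click handler: Gradio passes the inputs list positionally to fn, so i2v_prompt was added at the front of inputs to line up with the new leading prompt parameter of use_last_frame_as_input:

    # Positional mapping of components to parameters:
    #   i2v_prompt       -> prompt
    #   output_video     -> video_filepath
    #   enhance_checkbox -> do_enhance
    use_last_frame_button.click(
        fn=use_last_frame_as_input,
        inputs=[i2v_prompt, output_video, enhance_checkbox],
        outputs=[image_i2v, tabs],
    )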