# Spaces: Running on Zero
import os | |
import cv2 | |
import torch | |
import spaces | |
import imageio | |
import numpy as np | |
import gradio as gr | |
torch.jit.script = lambda f: f | |
import argparse | |
from utils.batch_inference import ( | |
BSRInferenceLoop, BIDInferenceLoop | |
) | |
# import subprocess | |
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) | |
# Torch device for this process: the GPU when CUDA is visible to PyTorch,
# otherwise the CPU.
# NOTE(review): nothing else in this file reads `device`; DiffBIR_restore
# hard-codes "cuda".
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_example(task):
    """Return the Gradio example rows for one restoration task.

    Args:
        task: ``"dn"`` (denoising) or ``"sr"`` (super-resolution).

    Returns:
        A new list of single-element rows. Each row holds the path of one
        bundled example video: ``examples/<name>.mp4`` for ``"dn"`` and
        ``examples/<name>_sr.mp4`` for ``"sr"``.

    Raises:
        KeyError: if *task* is neither ``"dn"`` nor ``"sr"``.
    """
    # 'dog-gooses' is not offered in the example gallery.
    stems = ('bus', 'koala', 'flamingo', 'rhino', 'elephant', 'sheep', 'dog-agility')
    suffix = {"dn": "", "sr": "_sr"}[task]
    return [[f'examples/{stem}{suffix}.mp4'] for stem in stems]
def update_prompt(input_video):
    """Return the default prompt for *input_video*, keyed by its file name.

    The file name is the part of the path after the last ``/``.

    NOTE(review): ``set_default_prompt`` is not defined anywhere in this file,
    so calling this raises ``NameError`` unless that name is provided
    elsewhere. Nothing in this file currently wires it up as a callback.
    """
    file_name = input_video.rsplit('/', 1)[-1]
    return set_default_prompt(file_name)
# Map each bundled example video's file name to a one-element list holding the
# folder of its pre-extracted frames: "<stem>.mp4" -> "examples_frames/<stem>"
# and "<stem>_sr.mp4" -> "examples_frames/<stem>_sr".
video_to_image = {
    f'{stem}{suffix}.mp4': [f'examples_frames/{stem}{suffix}']
    for suffix in ('', '_sr')
    for stem in ('bus', 'koala', 'dog-gooses', 'flamingo', 'rhino',
                 'elephant', 'sheep', 'dog-agility')
}
# The one irregular entry uses an underscore instead of the hyphen.
# NOTE(review): confirm the folder on disk really is "dog_gooses_sr".
video_to_image['dog-gooses_sr.mp4'] = ['examples_frames/dog_gooses_sr']
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode a sequence of images as an H.264 (libx264) video file.

    Args:
        image_list: Iterable of images (e.g. PIL images) that ``np.array``
            can convert; each one is cast to ``uint8`` before writing.
        output_path: Destination path passed to ``imageio.get_writer``.
        fps: Frame rate of the written video.
        max_frames: Write at most this many leading frames (default 20).
            ``None`` writes every frame. Expected to be >= 0.

    Images are converted one at a time, so images past the cap are never
    touched. The writer is closed even if conversion or encoding fails.
    """
    with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
        for index, img in enumerate(image_list):
            if max_frames is not None and index >= max_frames:
                break
            writer.append_data(np.array(img).astype(np.uint8))
def video2frames(video_path):
    """Decode every frame of a video into a directory of numbered JPEG files.

    The output directory is *video_path* with its extension removed
    (``clip.mp4`` -> ``clip``). If the path has no extension, ``_frames`` is
    appended instead, so the directory cannot collide with the file itself.
    Frames are written as ``00000.jpg``, ``00001.jpg``, ... in decode order.

    Args:
        video_path: Path to a video file that OpenCV can read.

    Returns:
        The path of the directory holding the extracted frames.

    Raises:
        ValueError: if OpenCV cannot open *video_path*.
        OSError: if a frame cannot be written.
    """
    # splitext instead of slicing off 4 characters, so extensions of any
    # length (".webm", ".mpeg", ...) are handled.
    img_path = os.path.splitext(video_path)[0]
    if img_path == video_path:
        img_path += "_frames"

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        os.makedirs(img_path, exist_ok=True)
        frame_count = 0
        while True:
            ret, frame = video.read()
            if not ret:  # no more frames could be read
                break
            frame_file = os.path.join(img_path, f"{frame_count:05}.jpg")
            if not cv2.imwrite(frame_file, frame):
                raise OSError(f"Failed to write frame: {frame_file}")
            frame_count += 1
    finally:
        # Always free the capture handle, even if decoding or writing fails.
        video.release()
    return img_path
def _frames_dir_for(input_video):
    """Return a directory of frames to feed the inference loop for *input_video*.

    Bundled example videos (matched by file name in ``video_to_image``) reuse
    their pre-extracted frame folder when it exists on disk. Anything else is
    decoded to JPEGs with ``video2frames``.

    NOTE(review): matching is by file name only, so an uploaded file that
    happens to be named like an example (e.g. ``bus.mp4``) gets the example's
    frames. Confirm this is acceptable.
    """
    video_name = os.path.basename(input_video)
    example_dirs = video_to_image.get(video_name)
    # If the pre-extracted folder is missing, extract instead of handing the
    # inference loop a nonexistent input path.
    if example_dirs and os.path.isdir(example_dirs[0]):
        return example_dirs[0]
    return video2frames(input_video)


def _build_inference_args(frames_path, prompt, sr_ratio, n_frames, n_steps,
                          guidance_scale, seed, n_prompt, task):
    """Assemble the argparse-style options object consumed by the DiffBIR inference loops."""
    return argparse.Namespace(
        # The CLI task choices are "sr", "dn", "fr", "fr_bg"; this demo
        # dispatches only "sr" and "dn".
        task=task,
        upscale=sr_ratio,
        # --- sampling parameters ---
        # int() guards against UI widgets delivering whole numbers as floats.
        steps=int(n_steps),
        better_start=True,
        tiled=False,
        tile_size=512,
        tile_stride=256,
        pos_prompt=prompt,
        neg_prompt=n_prompt,
        cfg_scale=guidance_scale,
        # --- input parameters ---
        input=frames_path,
        n_samples=1,
        batch_size=10,
        final_size=(480, 854),
        config="configs/inference/my_cldm.yaml",
        # --- guidance parameters (guidance is disabled) ---
        guidance=False,
        g_loss="w_mse",
        g_scale=0.0,
        g_start=1001,
        g_stop=-1,
        g_space="latent",
        g_repeat=1,
        # --- output parameters ---
        output=" ",
        # --- common parameters ---
        seed=int(seed),
        device="cuda",
        n_frames=int(n_frames),
        # --- latent control parameters ---
        warp_period=[0, 0.1],
        merge_period=[0, 0],
        ToMe_period=[0, 1],
        merge_ratio=[0.6, 0],
    )


def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    """Run DiffBIR video restoration on an uploaded or example video.

    Args:
        input_video: Path of the input video, as delivered by ``gr.Video``.
        prompt: Positive text prompt.
        sr_ratio: Upscale ratio. A hidden value of 1 is used for denoising.
        n_frames: Number of frames to process.
        n_steps: Number of sampling steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Random seed.
        n_prompt: Negative text prompt.
        task: ``"sr"`` (super-resolution) or ``"dn"`` (denoising).

    Returns:
        Whatever the selected inference loop's ``run()`` returns. It is fed to
        a ``gr.Video`` output, so presumably a video file path.

    Raises:
        gr.Error: if no video was provided.
        ValueError: if *task* is not ``"sr"`` or ``"dn"``.

    NOTE(review): ``spaces`` is imported but no ``@spaces.GPU`` decorator is
    visible in this file. On ZeroGPU the GPU must therefore be acquired inside
    the inference loops; confirm.
    """
    if not input_video:
        raise gr.Error("Please upload a video first.")
    # Validate the task before doing any (possibly slow) frame extraction.
    loop_cls = {"sr": BSRInferenceLoop, "dn": BIDInferenceLoop}.get(task)
    if loop_cls is None:
        raise ValueError(f"Unsupported task {task!r}; expected 'sr' or 'dn'.")

    frames_path = _frames_dir_for(input_video)
    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    args = _build_inference_args(frames_path, prompt, sr_ratio, n_frames, n_steps,
                                 guidance_scale, seed, n_prompt, task)
    try:
        return loop_cls(args).run()
    finally:
        # Release cached GPU memory between requests, even if inference failed.
        torch.cuda.empty_cache()
########
# demo #
########
intro = """
<div style="text-align:center">
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
DiffIR2VR - <small>Zero-Shot Video Restoration</small>
</h1>
<span>[<a target="_blank" href="https://jimmycv07.github.io/DiffIR2VR_web/">Project page</a>] [<a target="_blank" href="https://huggingface.co/papers/2406.06523">arXiv</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Note that this page is a limited demo of DiffIR2VR.
For more configurations, please visit our GitHub page. The code will be released soon!</div>
<div style="display:flex; justify-content: center;margin-top: 0.5em; color: red;">For super-resolution,
it is recommended that the final frame size (original size * upscale ratio) be around 480x854,
else the demo may fail due to lengthy inference times.</div>
</div>
"""

# Negative prompt pre-filled in both tabs.
_DEFAULT_NEGATIVE_PROMPT = "low quality, blurry, low-resolution, noisy, unsharp, weird textures"


def _build_restore_tab(tab_label, task, show_upscale):
    """Add one restoration tab to the enclosing ``gr.Blocks`` and wire it to DiffBIR_restore.

    Args:
        tab_label: Title shown on the tab.
        task: Task code forwarded to DiffBIR_restore (``"sr"`` or ``"dn"``).
            It also selects the example gallery through get_example.
        show_upscale: If True, show an "Upscale ratio" slider. Otherwise a
            hidden upscale ratio of 1 is passed.

    Components are local to this call, so the two tabs no longer rebind
    shared module-level names.
    """
    with gr.Tab(label=tab_label):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)
        with gr.Row():
            run_button = gr.Button("Restore your video !", visible=True)
        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content",
            )
            if show_upscale:
                sr_ratio = gr.Slider(label='Upscale ratio',
                                     minimum=1,
                                     maximum=16,
                                     value=4,
                                     step=0.5)
            else:
                sr_ratio = gr.Number(value=1, visible=False)
            n_frames = gr.Slider(label='Frames',
                                 minimum=1,
                                 maximum=60,
                                 value=10,
                                 step=1)
            n_steps = gr.Slider(label='Steps',
                                minimum=1,
                                maximum=100,
                                value=5,
                                step=1)
            guidance_scale = gr.Slider(label='Guidance Scale',
                                       minimum=0.1,
                                       maximum=30.0,
                                       value=4.0,
                                       step=0.1)
            seed = gr.Slider(label='Seed',
                             minimum=-1,
                             maximum=1000,
                             step=1,
                             randomize=True)
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value=_DEFAULT_NEGATIVE_PROMPT,
            )
            # Hidden component that carries the fixed task code into the callback.
            task_box = gr.Textbox(value=task, visible=False)

        # Prompt auto-fill from the video name (input_video.change ->
        # update_prompt) is disabled.
        run_button.click(fn=DiffBIR_restore,
                         inputs=[input_video,
                                 prompt,
                                 sr_ratio,
                                 n_frames,
                                 n_steps,
                                 guidance_scale,
                                 seed,
                                 n_prompt,
                                 task_box,
                                 ],
                         outputs=[output_video])
        gr.Examples(
            examples=get_example(task),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7,
        )


with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    _build_restore_tab("Super-resolution with DiffBIR", "sr", show_upscale=True)
    _build_restore_tab("Denoise with DiffBIR", "dn", show_upscale=False)

demo.queue()
demo.launch()