DiffIR2VR / app.py
jimmycv07's picture
uploaded videos from users
2ebea36
raw
history blame
No virus
11.7 kB
import os
import cv2
import torch
import spaces
import imageio
import numpy as np
import gradio as gr
# Replace torch.jit.script with a pass-through that returns its argument
# unchanged, so anything passed to it after this line stays a plain Python
# callable instead of being compiled.
# NOTE(review): presumably a workaround for TorchScript compilation failing in
# one of the modules imported below -- confirm the reason before removing.
torch.jit.script = lambda f: f
import argparse
from utils.batch_inference import (
BSRInferenceLoop, BIDInferenceLoop
)
# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# 'cuda' when torch reports a usable CUDA device, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_example(task):
    """Return the bundled example videos for ``task`` as Gradio example rows.

    Args:
        task: Task key, either "dn" or "sr".

    Returns:
        A list of one-element lists, each holding the path of one example
        video (the row format expected by ``gr.Examples``).

    Raises:
        KeyError: If ``task`` is not one of the known keys.
    """
    # The 'dog-gooses' clips were commented out of the original lists and
    # therefore stay excluded from both example sets.
    denoise_clips = (
        'examples/bus.mp4',
        'examples/koala.mp4',
        'examples/flamingo.mp4',
        'examples/rhino.mp4',
        'examples/elephant.mp4',
        'examples/sheep.mp4',
        'examples/dog-agility.mp4',
    )
    upscale_clips = (
        'examples/bus_sr.mp4',
        'examples/koala_sr.mp4',
        'examples/flamingo_sr.mp4',
        'examples/rhino_sr.mp4',
        'examples/elephant_sr.mp4',
        'examples/sheep_sr.mp4',
        'examples/dog-agility_sr.mp4',
    )
    rows_by_task = {
        "dn": [[clip] for clip in denoise_clips],
        "sr": [[clip] for clip in upscale_clips],
    }
    return rows_by_task[task]
def update_prompt(input_video):
    """Pass the video's file name to ``set_default_prompt`` and return its result.

    Args:
        input_video: Path of the video, using '/' as the separator.

    Returns:
        Whatever ``set_default_prompt`` returns for the file name.
    """
    # NOTE(review): set_default_prompt is not defined in the visible part of
    # this file, and the only caller of update_prompt is commented out --
    # confirm the helper exists before re-enabling that hook.
    _, _, video_name = input_video.rpartition('/')
    return set_default_prompt(video_name)
# Lookup table: example video file name -> one-element list holding the
# directory with that video's frames. DiffBIR_restore reads element [0].
# NOTE(review): the 'dog-gooses_sr.mp4' entry points at 'dog_gooses_sr'
# (underscore), unlike its siblings -- kept as-is; verify against the
# directory name on disk.
video_to_image = {
    video_name: [frames_dir]
    for video_name, frames_dir in (
        ('bus.mp4', 'examples_frames/bus'),
        ('koala.mp4', 'examples_frames/koala'),
        ('dog-gooses.mp4', 'examples_frames/dog-gooses'),
        ('flamingo.mp4', 'examples_frames/flamingo'),
        ('rhino.mp4', 'examples_frames/rhino'),
        ('elephant.mp4', 'examples_frames/elephant'),
        ('sheep.mp4', 'examples_frames/sheep'),
        ('dog-agility.mp4', 'examples_frames/dog-agility'),
        ('bus_sr.mp4', 'examples_frames/bus_sr'),
        ('koala_sr.mp4', 'examples_frames/koala_sr'),
        ('dog-gooses_sr.mp4', 'examples_frames/dog_gooses_sr'),
        ('flamingo_sr.mp4', 'examples_frames/flamingo_sr'),
        ('rhino_sr.mp4', 'examples_frames/rhino_sr'),
        ('elephant_sr.mp4', 'examples_frames/elephant_sr'),
        ('sheep_sr.mp4', 'examples_frames/sheep_sr'),
        ('dog-agility_sr.mp4', 'examples_frames/dog-agility_sr'),
    )
}
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode the leading images of ``image_list`` into a libx264 video file.

    Args:
        image_list: Iterable of images that ``np.array`` can convert (the
            original code comment describes them as PIL Images).
        output_path: Destination handed to ``imageio.get_writer``.
        fps: Frame rate of the written video.
        max_frames: Upper bound on how many leading images are written.
            Defaults to 20, the limit that used to be hard-coded; ``None``
            writes every image.
    """
    from itertools import islice

    # Truncate before converting so images beyond the limit are never turned
    # into arrays.
    frames = [
        np.array(img).astype(np.uint8)
        for img in islice(image_list, max_frames)
    ]

    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Close the writer even when a frame fails to encode, so the output
        # handle is not leaked.
        writer.close()
def video2frames(video_path):
    """Dump every frame of a video into a directory of numbered JPEG files.

    The output directory is the video path with its extension removed
    (``clip.mp4`` -> ``clip``); frames are saved inside it as ``00000.jpg``,
    ``00001.jpg``, ... in read order.

    Args:
        video_path: Path of the video file to read.

    Returns:
        The path of the directory that holds the extracted frames.
    """
    # Strip the real extension instead of a fixed four characters so that
    # names such as "clip.webm" or "clip.mpeg" map to "clip" as well.
    stem, ext = os.path.splitext(video_path)
    # Without an extension the stem equals the file itself, which cannot also
    # be a directory, so fall back to a sibling name.
    img_path = stem if ext else f"{video_path}_frames"
    os.makedirs(img_path, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while True:
            ret, frame = video.read()
            # A failed read marks the end of the stream.
            if not ret:
                break
            cv2.imwrite(f"{img_path}/{frame_count:05}.jpg", frame)
            frame_count += 1
    finally:
        # Release the capture even if reading or writing a frame raised.
        video.release()
    return img_path
@spaces.GPU(duration=120)
def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    """Restore a video with the inference loop that matches ``task``.

    Args:
        input_video: Path of the input video file.
        prompt: Positive text prompt (stored as ``pos_prompt``).
        sr_ratio: Upscale factor (stored as ``upscale``).
        n_frames: Frame-count setting forwarded as ``n_frames``.
        n_steps: Sampling steps (stored as ``steps``).
        guidance_scale: Guidance scale (stored as ``cfg_scale``).
        seed: Random seed forwarded to the inference loop.
        n_prompt: Negative text prompt (stored as ``neg_prompt``).
        task: "sr" runs ``BSRInferenceLoop``; "dn" runs ``BIDInferenceLoop``.

    Returns:
        The value returned by the inference loop's ``run()``; the UI feeds it
        to the output video component.

    Raises:
        gr.Error: If no input video was provided.
        ValueError: If ``task`` is neither "sr" nor "dn".
    """
    # Fail early with a readable message instead of an AttributeError on None.
    if not input_video:
        raise gr.Error("Please upload or select an input video first.")
    # Reject unknown tasks up front; previously they surfaced as an
    # UnboundLocalError on the return statement.
    if task not in ("sr", "dn"):
        raise ValueError(f"Unsupported task {task!r}; expected 'sr' or 'dn'.")

    # NOTE(review): bundled examples are recognised by file name only, so an
    # upload that shares a name with an example reuses that example's frames.
    video_name = os.path.basename(input_video)
    if video_name in video_to_image:
        frames_path = video_to_image[video_name][0]
    else:
        frames_path = video2frames(input_video)
    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    args = argparse.Namespace(
        task=task,
        upscale=sr_ratio,
        # sampling parameters
        steps=n_steps,
        better_start=True,
        tiled=False,
        tile_size=512,
        tile_stride=256,
        pos_prompt=prompt,
        neg_prompt=n_prompt,
        cfg_scale=guidance_scale,
        # input parameters
        input=frames_path,
        n_samples=1,
        batch_size=10,
        final_size=(480, 854),
        config="configs/inference/my_cldm.yaml",
        # guidance parameters
        guidance=False,
        g_loss="w_mse",
        g_scale=0.0,
        g_start=1001,
        g_stop=-1,
        g_space="latent",
        g_repeat=1,
        # output parameters
        output=" ",
        # common parameters
        seed=seed,
        device="cuda",
        n_frames=n_frames,
        # latent control parameters
        warp_period=[0, 0.1],
        merge_period=[0, 0],
        ToMe_period=[0, 1],
        merge_ratio=[0.6, 0],
    )

    loop_cls = BSRInferenceLoop if task == "sr" else BIDInferenceLoop
    try:
        restored_vid_path = loop_cls(args).run()
    finally:
        # Free cached GPU memory even when inference fails.
        torch.cuda.empty_cache()
    return restored_vid_path
########
# demo #
########
# HTML header shown at the top of the page (rendered with gr.HTML): title,
# project-page and arXiv links, and usage notes for the demo.
intro = """
<div style="text-align:center">
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
DiffIR2VR - <small>Zero-Shot Video Restoration</small>
</h1>
<span>[<a target="_blank" href="https://jimmycv07.github.io/DiffIR2VR_web/">Project page</a>] [<a target="_blank" href="https://huggingface.co/papers/2406.06523">arXiv</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Note that this page is a limited demo of DiffIR2VR.
For more configurations, please visit our GitHub page. The code will be released soon!</div>
<div style="display:flex; justify-content: center;margin-top: 0.5em; color: red;">For super-resolution,
it is recommended that the final frame size (original size * upscale ratio) be around 480x854,
else the demo may fail due to lengthy inference times.</div>
</div>
"""
def _build_restoration_tab(tab_label, task_code):
    """Create one restoration tab and wire its button to ``DiffBIR_restore``.

    Both tabs share the same widgets; they differ only in the upscale control
    (a visible slider for "sr", a hidden constant 1 for "dn") and in the
    example set that is listed.

    Args:
        tab_label: Title shown on the tab.
        task_code: "sr" or "dn"; passed to ``DiffBIR_restore`` through a
            hidden textbox and used to pick the examples.
    """
    # NOTE(review): the source this was rebuilt from had lost its indentation,
    # so the nesting of rows/accordion below is a reconstruction -- verify the
    # rendered layout.
    is_sr = task_code == "sr"
    with gr.Tab(label=tab_label):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)
        with gr.Row():
            run_button = gr.Button("Restore your video !", visible=True)
        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content",
            )
            if is_sr:
                sr_ratio = gr.Slider(
                    label='Upscale ratio',
                    minimum=1,
                    maximum=16,
                    value=4,
                    step=0.5,
                )
            n_frames = gr.Slider(
                label='Frames',
                minimum=1,
                maximum=60,
                value=10,
                step=1,
            )
            n_steps = gr.Slider(
                label='Steps',
                minimum=1,
                maximum=100,
                value=5,
                step=1,
            )
            guidance_scale = gr.Slider(
                label='Guidance Scale',
                minimum=0.1,
                maximum=30.0,
                value=4.0,
                step=0.1,
            )
            seed = gr.Slider(
                label='Seed',
                minimum=-1,
                maximum=1000,
                step=1,
                randomize=True,
            )
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value="low quality, blurry, low-resolution, noisy, unsharp, weird textures",
            )
            # Hidden components carry fixed values into the callback.
            task = gr.Textbox(value=task_code, visible=False)
            if not is_sr:
                sr_ratio = gr.Number(value=1, visible=False)

        # Prompt auto-fill on input_video.change (via update_prompt) was
        # disabled in the original script and stays disabled here.
        run_button.click(
            fn=DiffBIR_restore,
            inputs=[
                input_video,
                prompt,
                sr_ratio,
                n_frames,
                n_steps,
                guidance_scale,
                seed,
                n_prompt,
                task,
            ],
            outputs=[output_video],
        )
        gr.Examples(
            examples=get_example(task_code),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7,
        )


# Assemble the page: the intro header followed by one tab per task.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    _build_restoration_tab("Super-resolution with DiffBIR", "sr")
    _build_restoration_tab("Denoise with DiffBIR", "dn")

demo.queue()
demo.launch()