kjhlhkljh / server.py
newturok's picture
asdfas
1c5ae9c
import os
import gradio as gr
from sd_model_cfg import model_dict
from app import process, process0, process1, process2, get_frame_count, cfg_to_input
# Markdown/HTML text rendered at the top of the page through gr.Markdown.
# NOTE(review): the limits quoted in this text (at most 8 key frames, at most
# 512x768) are not what MAX_KEYFRAME below is currently set to -- confirm
# which of the two is intended before changing either.
DESCRIPTION = '''
## Rerender A Video
### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we set limitations to the maximum frame number (8) and the maximum frame resolution (512x768).
### The running time of a video of size 512x640 is about 1 minute per keyframe under T4 GPU.
### How to use:
1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before run the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**
4. **Run Propagation**: propogate the key frames to other frames for full video translation. This part will be released upon the publication of the paper.
### Tips:
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings and even unuse it to avoid color jittering.
4. `revAnimated_v11` model for non-photorealstic style, `realisticVisionV20_v20` model for photorealstic style.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffuser/Automatic1111 models to the original one.
**This code is for research purpose and non-commercial use only.**
<a href="https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
'''
# Upper bound applied to the "Number of key frames" slider by the callbacks
# that recompute it from the video length.
MAX_KEYFRAME = 100000000

# Module-level state shared by the slider callbacks. They are declared with
# `global` inside the callbacks and one of them is compared against None, so
# they must exist (as None) before any video has been processed; otherwise
# that comparison raises NameError instead of taking the "no video" branch.
video_frame_count = None
global_video_path = None
# Build the page. `.queue()` is chained here and its result is used as the
# context manager below, i.e. it hands back the Blocks object.
# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting of rows/columns/accordions below is reconstructed from the
# order of the statements -- confirm it matches the intended layout.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: inputs, run buttons and all tunable options.
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame rsolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input.'
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key fame translation',
                              open=False):
                # Both sliders start pinned at 1; the upload/change callbacks
                # widen them once the length of the video is known.
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')
                # No trailing comma after the call: the name must be bound to
                # the component itself, not to a 1-tuple wrapping it.
                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame contraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for '
                          'cross-frame attention every N key frames '
                          '(recommend N*K>=10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion stength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at boundary')
            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                # os.listdir returns entries in arbitrary order; sort them so
                # the example rows are listed in the same order on every run.
                config_list = sorted(os.listdir(config_dir))
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # The video file does not exist, skipped
                        pass
                # Components handed to the process* callbacks, in the order
                # in which they are passed as positional inputs.
                ips = [
                    prompt, image_resolution, control_strength, color_preserve,
                    left_crop, right_crop, top_crop, bottom_crop, control_type,
                    low_threshold, high_threshold, ddim_steps, scale, seed,
                    sd_model, a_prompt, n_prompt, interval, keyframe_count,
                    x0_strength, use_constraints, cross_start, cross_end,
                    style_update_freq, warp_start, warp_end, mask_start,
                    mask_end, ada_start, ada_end, mask_strength,
                    inner_strength, smooth_boundary
                ]
        # Right column: outputs.
        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    with gr.Row():
        # Each example row supplies the video path followed by the `ips`
        # values; outputs are produced by process0 and cached.
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)
def input_uploaded(path):
    """Reset the key-frame sliders after a video has been uploaded.

    Args:
        path: Location of the uploaded video, as delivered by the
            `input_path` component.

    Returns:
        A pair of slider updates: the first for the key-frame interval
        slider, the second for the key-frame count slider.

    Raises:
        gr.Error: If the video has two frames or fewer.
    """
    global video_frame_count, global_video_path

    frame_count = get_frame_count(path)
    if frame_count <= 2:
        # The two literals are concatenated, so the first needs a trailing
        # space to keep the sentences apart in the displayed message.
        raise gr.Error('The input video is too short! '
                       'Please input another video.')

    usable_frames = frame_count - 2
    default_interval = min(10, usable_frames)
    max_keyframe = min(usable_frames // default_interval, MAX_KEYFRAME)

    # Remember the video so later slider callbacks can recompute limits.
    video_frame_count = frame_count
    global_video_path = path

    # The interval is bounded by the video length, not by the key-frame cap:
    # any interval above `usable_frames` would leave room for zero key frames.
    interval_update = gr.Slider.update(value=default_interval,
                                       maximum=usable_frames)
    keyframe_update = gr.Slider.update(value=max_keyframe,
                                       maximum=max_keyframe)
    return interval_update, keyframe_update
def input_changed(path):
    """Adjust the slider maxima whenever the input video value changes.

    Unlike the upload callback this only updates the maxima and leaves the
    current slider values alone.

    Args:
        path: Location of the current video, or None.

    Returns:
        A pair of slider updates: the first for the key-frame interval
        slider, the second for the key-frame count slider. Both maxima
        collapse to 1 when there is no usable video.
    """
    global video_frame_count, global_video_path

    # NOTE(review): assumes the change event can deliver None when the video
    # is cleared; handled here rather than relying on get_frame_count(None).
    if path is None:
        return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

    frame_count = get_frame_count(path)
    if frame_count <= 2:
        return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

    usable_frames = frame_count - 2
    default_interval = min(10, usable_frames)
    max_keyframe = min(usable_frames // default_interval, MAX_KEYFRAME)

    # Remember the video so later slider callbacks can recompute limits.
    video_frame_count = frame_count
    global_video_path = path

    # The interval is measured in frames, so its ceiling is the usable video
    # length; the key-frame count is what `max_keyframe` bounds.
    return (gr.Slider.update(maximum=usable_frames),
            gr.Slider.update(maximum=max_keyframe))
def interval_changed(interval):
    """Recompute the key-frame count slider for a new key-frame interval.

    Args:
        interval: Current value of the key-frame interval slider.

    Returns:
        A slider update for the key-frame count slider. It is an empty
        update (no change) when no video length has been recorded yet.
    """
    # The frame count is only bound once a long-enough video has been seen;
    # look it up defensively so an early event cannot raise NameError.
    frame_count = globals().get('video_frame_count')
    if frame_count is None:
        return gr.Slider.update()

    # Apply the same MAX_KEYFRAME cap the upload/change callbacks use.
    max_keyframe = min((frame_count - 2) // interval, MAX_KEYFRAME)
    return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)
# Keep the key-frame sliders in sync with the selected video and interval.
slider_pair = [interval, keyframe_count]
input_path.change(fn=input_changed, inputs=input_path, outputs=slider_pair)
input_path.upload(fn=input_uploaded, inputs=input_path, outputs=slider_pair)
interval.change(fn=interval_changed,
                inputs=interval,
                outputs=keyframe_count)

# Run buttons: all three take the same ordered list of option components.
run_button.click(
    fn=process,
    inputs=ips,
    outputs=[result_image, result_keyframe],
)
run_button1.click(
    fn=process1,
    inputs=ips,
    outputs=[result_image],
)
run_button2.click(
    fn=process2,
    inputs=ips,
    outputs=[result_keyframe],
)
def process3():
    """Placeholder handler for the propagation step.

    Raises:
        gr.Error: Always; the message tells the user the feature is not
            available yet.
    """
    message = ("Coming Soon. Full code for full video translation will be "
               "released upon the publication of the paper.")
    raise gr.Error(message)
# The propagation button takes no inputs; its handler only raises an error.
run_button3.click(
    fn=process3,
    outputs=[result_keyframe],
)

# Configure the request queue, then start the server on all interfaces.
# `queue` hands back the Blocks object (the file already relies on that when
# creating `block`), so the two calls can be chained.
block.queue(concurrency_count=1, max_size=20).launch(server_name='0.0.0.0')