FateZero / app_fatezero.py
chenyangqi's picture
update gradio version
24b6bf7
raw history blame
No virus
10.6 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
# from inference import InferencePipeline
# from FateZero import test_fatezero
from inference_fatezero import merge_config_then_run
# class InferenceUtil:
# def __init__(self, hf_token: str | None):
# self.hf_token = hf_token
# def load_model_info(self, model_id: str) -> tuple[str, str]:
# # todo FIXME
# try:
# card = InferencePipeline.get_model_card(model_id, self.hf_token)
# except Exception:
# return '', ''
# base_model = getattr(card.data, 'base_model', '')
# training_prompt = getattr(card.data, 'training_prompt', '')
# return base_model, training_prompt
# Markdown heading rendered at the top of the demo page.
TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'

# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Alias for the inference entry point.  The UI wiring in this file calls
# `merge_config_then_run` directly rather than going through this name.
pipe = merge_config_then_run
# Build and launch the FateZero editing demo.
#
# The left column holds the source video, the temporal and spatial crop
# settings, the data/checkpoint selection, the prompts and the DDIM settings.
# The right column shows the result video and the attention-fusion controls.
#
# The values of the components in `inputs` are passed positionally, in list
# order, to `merge_config_then_run`.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            with gr.Accordion('Input Video', open=True):
                user_input_video = gr.File(label='Input Source Video')
            with gr.Accordion('Temporal Crop offset and Sampling Stride', open=False):
                # A clip needs at least one frame and a stride of at least one.
                # The previous minimum of 0 let users request an empty clip or
                # a zero sampling stride.
                n_sample_frame = gr.Slider(label='Number of Frames in Video',
                                           minimum=1,
                                           maximum=32,
                                           step=1,
                                           value=8)
                stride = gr.Slider(label='Temporal sampling stride in Video',
                                   minimum=1,
                                   maximum=20,
                                   step=1,
                                   value=1)
                start_sample_frame = gr.Number(label='Start frame in the video',
                                               value=0,
                                               precision=0)
            with gr.Accordion('Spatial Crop offset', open=False):
                left_crop = gr.Number(label='Left crop',
                                      value=0,
                                      precision=0)
                right_crop = gr.Number(label='Right crop',
                                       value=0,
                                       precision=0)
                top_crop = gr.Number(label='Top crop',
                                     value=0,
                                     precision=0)
                bottom_crop = gr.Number(label='Bottom crop',
                                        value=0,
                                        precision=0)
            # Crop offsets, in left/right/top/bottom order.
            offset_list = [
                left_crop,
                right_crop,
                top_crop,
                bottom_crop,
            ]
            # Frame-sampling settings followed by the crop offsets; appended
            # to the end of `inputs` below.
            ImageSequenceDataset_list = [
                start_sample_frame,
                n_sample_frame,
                stride,
            ] + offset_list
            data_path = gr.Dropdown(
                label='provided data path',
                choices=[
                    'FateZero/data/teaser_car-turn',
                    'FateZero/data/style/sunflower',
                    # TODO: add shape-editing data here
                ],
                value='FateZero/data/teaser_car-turn')
            model_id = gr.Dropdown(
                label='Model ID',
                choices=[
                    'CompVis/stable-diffusion-v1-4',
                    # TODO: add shape-editing checkpoints here
                ],
                value='CompVis/stable-diffusion-v1-4')
            with gr.Accordion('Text Prompt', open=True):
                source_prompt = gr.Textbox(label='Source Prompt',
                                           info='A good prompt describes each frame and most objects in video. Especially, it has the object or attribute that we want to edit or preserve.',
                                           max_lines=1,
                                           placeholder='Example: "a silver jeep driving down a curvy road in the countryside"',
                                           value='a silver jeep driving down a curvy road in the countryside')
                target_prompt = gr.Textbox(label='Target Prompt',
                                           info='A reasonable composition of video may achieve better results(e.g., "sunflower" video with "Van Gogh" prompt is better than "sunflower" with "Monet")',
                                           max_lines=1,
                                           placeholder='Example: "watercolor painting of a silver jeep driving down a curvy road in the countryside"',
                                           value='watercolor painting of a silver jeep driving down a curvy road in the countryside')
            with gr.Accordion('DDIM Parameters', open=True):
                # At least one sampling step is required; 0 was previously
                # selectable.
                num_steps = gr.Slider(label='Number of Steps',
                                      info='larger value has better editing capacity, but takes more time and memory',
                                      minimum=1,
                                      maximum=50,
                                      step=1,
                                      value=10)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)
            run_button = gr.Button('Generate')
            # TODO: replace this placeholder with usage notes.
            gr.Markdown('''
todo
''')
        with gr.Column():
            result = gr.Video(label='Result')
            # NOTE(review): `.style()` is deprecated in newer Gradio releases;
            # verify against the pinned version before upgrading Gradio.
            result.style(height=512, width=512)
            with gr.Accordion('FateZero Parameters for attention fusing', open=True):
                cross_replace_steps = gr.Slider(label='cross-attention replace steps',
                                                info='More steps, replace more cross attention to preserve semantic layout.',
                                                minimum=0.0,
                                                maximum=1.0,
                                                step=0.1,
                                                value=0.7)
                self_replace_steps = gr.Slider(label='self-attention replace steps',
                                               info='More steps, replace more spatial-temporal self-attention to preserve geometry and motion.',
                                               minimum=0.0,
                                               maximum=1.0,
                                               step=0.1,
                                               value=0.7)
                enhance_words = gr.Textbox(label='words to be enhanced',
                                           info='Amplify the target-words cross attention',
                                           max_lines=1,
                                           placeholder='Example: "watercolor "',
                                           value='watercolor')
                enhance_words_value = gr.Slider(label='Amplify the target cross-attention',
                                                info='larger value, more elements of target words',
                                                minimum=0.0,
                                                maximum=20.0,
                                                step=1,
                                                value=10)

    # Components shared by the example rows and the run inputs.  Keeping one
    # list guarantees that each example column lines up with the leading
    # positional arguments of `merge_config_then_run`.
    example_inputs = [
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
    ]

    with gr.Row():
        # Each row gives one value per component in `example_inputs`, in order.
        examples = [
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/teaser_car-turn',
                'a silver jeep driving down a curvy road in the countryside',
                'watercolor painting of a silver jeep driving down a curvy road in the countryside',
                0.8,
                0.8,
                "watercolor",
                10,
                10,
                7.5,
            ],
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/style/sunflower',
                'a yellow sunflower',
                'van gogh style painting of a yellow sunflower',
                0.5,
                0.5,
                'van gogh',
                10,
                10,
                7.5,
            ],
        ]
        # Examples supply only the first ten inputs.  Presumably
        # merge_config_then_run has defaults for the video and crop arguments
        # (TODO confirm).  Outputs are pre-computed only when running on HF
        # Spaces (SYSTEM == 'spaces').
        gr.Examples(examples=examples,
                    inputs=example_inputs,
                    outputs=result,
                    fn=merge_config_then_run,
                    cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Full argument list: the shared settings, then the uploaded video, then
    # the frame-sampling and crop settings.
    inputs = [
        *example_inputs,
        user_input_video,
        *ImageSequenceDataset_list,
    ]
    # Pressing Enter in the target prompt and clicking "Generate" trigger the
    # same run.
    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)

demo.queue().launch()