FateZero / app_fatezero.py
chenyangqi's picture
initial commit based on tune-a-video
b57c333
raw history blame
No virus
7.01 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
# from inference import InferencePipeline
# from FateZero import test_fatezero
from inference_fatezero import merge_config_then_run
# class InferenceUtil:
# def __init__(self, hf_token: str | None):
# self.hf_token = hf_token
# def load_model_info(self, model_id: str) -> tuple[str, str]:
# # todo FIXME
# try:
# card = InferencePipeline.get_model_card(model_id, self.hf_token)
# except Exception:
# return '', ''
# base_model = getattr(card.data, 'base_model', '')
# training_prompt = getattr(card.data, 'training_prompt', '')
# return base_model, training_prompt
# Markdown heading rendered at the top of the demo page.
TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'

# Hugging Face access token from the environment; None when the variable is
# unset. NOTE(review): not referenced by the visible UI code — kept in case
# other code imports it.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Alias of the inference entry point, kept for compatibility with code that
# expects a module-level `pipe`; the UI itself calls merge_config_then_run.
pipe = merge_config_then_run
# Gradio front end for FateZero: collects the editing parameters, passes them
# to merge_config_then_run and displays the returned video. Gradio calls the
# function with the component values positionally, in the order of `inputs`
# below (assumed to match merge_config_then_run's signature — see
# inference_fatezero).
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            with gr.Box():
                model_id = gr.Dropdown(
                    label='Model ID',
                    choices=[
                        'CompVis/stable-diffusion-v1-4',
                        # add shape editing ckpt here
                    ],
                    value='CompVis/stable-diffusion-v1-4')
                # TODO: restore the "Model info" accordion (base model and
                # training prompt) once InferenceUtil.load_model_info is back.
                data_path = gr.Dropdown(
                    label='data path',
                    choices=[
                        'FateZero/data/teaser_car-turn',
                        'FateZero/data/style/sunflower',
                        # add more input video folders here
                    ],
                    value='FateZero/data/teaser_car-turn')
            source_prompt = gr.Textbox(
                label='Source Prompt',
                max_lines=1,
                placeholder='Example: "a silver jeep driving down a curvy road in the countryside"')
            target_prompt = gr.Textbox(
                label='Target Prompt',
                max_lines=1,
                placeholder='Example: "watercolor painting of a silver jeep driving down a curvy road in the countryside"')
            cross_replace_steps = gr.Slider(label='cross-attention replace steps',
                                            minimum=0.0,
                                            maximum=1.0,
                                            step=0.1,
                                            value=0.7)
            self_replace_steps = gr.Slider(label='self-attention replace steps',
                                           minimum=0.0,
                                           maximum=1.0,
                                           step=0.1,
                                           value=0.7)
            # Placeholder matches the example value below (no trailing space).
            enhance_words = gr.Textbox(label='words to be enhanced',
                                       max_lines=1,
                                       placeholder='Example: "watercolor"')
            enhance_words_value = gr.Slider(label='Amplify the target cross-attention',
                                            minimum=0.0,
                                            maximum=20.0,
                                            step=1,
                                            value=10)
            with gr.Accordion('DDIM Parameters', open=False):
                # minimum=1: a zero-step DDIM schedule is invalid (presumably
                # diffusers' DDIMScheduler, which divides by the step count).
                num_steps = gr.Slider(label='Number of Steps',
                                      minimum=1,
                                      maximum=100,
                                      step=1,
                                      value=50)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)
            run_button = gr.Button('Generate')
            # TODO: replace this placeholder with usage notes.
            gr.Markdown('\ntodo\n')
        with gr.Column():
            result = gr.Video(label='Result')

    # Single source of truth for the argument order shared by the examples
    # table and the event handlers.
    inputs = [
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
    ]

    with gr.Row():
        # Each row lists values in the same order as `inputs`.
        examples = [
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/teaser_car-turn',
                'a silver jeep driving down a curvy road in the countryside',
                'watercolor painting of a silver jeep driving down a curvy road in the countryside',
                0.8,
                0.8,
                "watercolor",
                10,
                10,
                7.5,
            ],
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/style/sunflower',
                'a yellow sunflower',
                'van gogh style painting of a yellow sunflower',
                0.5,
                0.5,
                'van gogh',
                10,
                50,
                7.5,
            ],
        ]
        # Examples are pre-computed at startup only when running on Spaces.
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=merge_config_then_run,
                    cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Pressing Enter in the target prompt and clicking Generate run the same job.
    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)

demo.queue().launch()