#!/usr/bin/env python
"""Gradio demo for FateZero text-driven video editing.

The UI collects a Stable Diffusion checkpoint id, an input data path, a
source/target prompt pair and attention-control / DDIM parameters, passes them
positionally to ``inference_fatezero.merge_config_then_run`` and shows that
function's output in a video component.
"""

from __future__ import annotations

import os

import gradio as gr

from inference_fatezero import merge_config_then_run

TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'

# Neither name is used by the UI below; both are kept as module-level names so
# existing importers of this module keep working.
HF_TOKEN = os.getenv('HF_TOKEN')
pipe = merge_config_then_run

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            with gr.Box():
                model_id = gr.Dropdown(
                    label='Model ID',
                    choices=[
                        'CompVis/stable-diffusion-v1-4',
                        # TODO: add shape-editing checkpoints here.
                    ],
                    value='CompVis/stable-diffusion-v1-4')
            data_path = gr.Dropdown(
                label='data path',
                choices=[
                    'FateZero/data/teaser_car-turn',
                    'FateZero/data/style/sunflower',
                    # TODO: add shape-editing data here.
                ],
                value='FateZero/data/teaser_car-turn')
            source_prompt = gr.Textbox(
                label='Source Prompt',
                max_lines=1,
                placeholder='Example: "a silver jeep driving down a curvy road in the countryside"')
            target_prompt = gr.Textbox(
                label='Target Prompt',
                max_lines=1,
                placeholder='Example: "watercolor painting of a silver jeep driving down a curvy road in the countryside"')
            cross_replace_steps = gr.Slider(label='cross-attention replace steps',
                                            minimum=0.0,
                                            maximum=1.0,
                                            step=0.1,
                                            value=0.7)
            self_replace_steps = gr.Slider(label='self-attention replace steps',
                                           minimum=0.0,
                                           maximum=1.0,
                                           step=0.1,
                                           value=0.7)
            enhance_words = gr.Textbox(label='words to be enhanced',
                                       max_lines=1,
                                       placeholder='Example: "watercolor "')
            enhance_words_value = gr.Slider(label='Amplify the target cross-attention',
                                            minimum=0.0,
                                            maximum=20.0,
                                            step=1,
                                            value=10)
            with gr.Accordion('DDIM Parameters', open=False):
                # Minimum is 1, not 0: a DDIM scheduler divides the number of
                # training timesteps by this value when building its schedule,
                # so 0 steps fails instead of producing a video.
                num_steps = gr.Slider(label='Number of Steps',
                                      minimum=1,
                                      maximum=100,
                                      step=1,
                                      value=50)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)

            run_button = gr.Button('Generate')

            # TODO: replace this placeholder with usage notes.
            gr.Markdown('''
            todo
            ''')
        with gr.Column():
            result = gr.Video(label='Result')

    # Gradio passes component values to the callback positionally in this
    # order, so it must match the parameter order of merge_config_then_run
    # (defined in inference_fatezero, not shown here). It is defined once and
    # shared by the examples and the event handlers so the two cannot drift.
    inputs = [
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
    ]

    with gr.Row():
        # Each row supplies one value per entry of `inputs`, in the same order.
        examples = [
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/teaser_car-turn',
                'a silver jeep driving down a curvy road in the countryside',
                'watercolor painting of a silver jeep driving down a curvy road in the countryside',
                0.8,
                0.8,
                "watercolor",
                10,
                10,
                7.5,
            ],
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/style/sunflower',
                'a yellow sunflower',
                'van gogh style painting of a yellow sunflower',
                0.5,
                0.5,
                'van gogh',
                10,
                50,
                7.5,
            ],
        ]
        # Example outputs are precomputed only when the SYSTEM env var is
        # 'spaces'; elsewhere examples just populate the inputs.
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=merge_config_then_run,
                    cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Pressing Enter in the target prompt and clicking Generate both run the
    # same editing call.
    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)

demo.queue().launch()