#!/usr/bin/env python
from __future__ import annotations

import os

import gradio as gr

# from inference import InferencePipeline
# from FateZero import test_fatezero
from inference_fatezero import merge_config_then_run


# class InferenceUtil:
#     def __init__(self, hf_token: str | None):
#         self.hf_token = hf_token
#
#     def load_model_info(self, model_id: str) -> tuple[str, str]:
#         # todo FIXME
#         try:
#             card = InferencePipeline.get_model_card(model_id, self.hf_token)
#         except Exception:
#             return '', ''
#         base_model = getattr(card.data, 'base_model', '')
#         training_prompt = getattr(card.data, 'training_prompt', '')
#         return base_model, training_prompt


TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'

HF_TOKEN = os.getenv('HF_TOKEN')
# pipe = InferencePipeline(HF_TOKEN)
pipe = merge_config_then_run
# app = InferenceUtil(HF_TOKEN)

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            with gr.Accordion('Input Video', open=True):
                user_input_video = gr.File(label='Input Source Video')

                with gr.Accordion('Temporal Crop offset and Sampling Stride', open=False):
                    n_sample_frame = gr.Slider(
                        label='Number of Frames in Video',
                        minimum=0,
                        maximum=32,
                        step=1,
                        value=8)
                    stride = gr.Slider(
                        label='Temporal sampling stride in Video',
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=1)
                    start_sample_frame = gr.Number(
                        label='Start frame in the video', value=0, precision=0)

                with gr.Accordion('Spatial Crop offset', open=False):
                    left_crop = gr.Number(label='Left crop', value=0, precision=0)
                    right_crop = gr.Number(label='Right crop', value=0, precision=0)
                    top_crop = gr.Number(label='Top crop', value=0, precision=0)
                    bottom_crop = gr.Number(label='Bottom crop', value=0, precision=0)
                    offset_list = [
                        left_crop,
                        right_crop,
                        top_crop,
                        bottom_crop,
                    ]

                # Frame-selection and crop controls, grouped in the order used for
                # the ImageSequenceDataset arguments.
                ImageSequenceDataset_list = [
                    start_sample_frame,
                    n_sample_frame,
                    stride,
                ] + offset_list

            data_path = gr.Dropdown(
                label='Provided data path',
                choices=[
                    'FateZero/data/teaser_car-turn',
                    'FateZero/data/style/sunflower',
                    # add shape editing ckpt here
                ],
                value='FateZero/data/teaser_car-turn')
            model_id = gr.Dropdown(
                label='Model ID',
                choices=[
                    'CompVis/stable-diffusion-v1-4',
                    # add shape editing ckpt here
                ],
                value='CompVis/stable-diffusion-v1-4')

            # with gr.Accordion(
            #         label='Model info (Base model and prompt used for training)',
            #         open=False):
            #     with gr.Row():
            #         base_model_used_for_training = gr.Text(
            #             label='Base model', interactive=False)
            #         prompt_used_for_training = gr.Text(
            #             label='Training prompt', interactive=False)

            with gr.Accordion('Text Prompt', open=True):
                source_prompt = gr.Textbox(
                    label='Source Prompt',
                    info='A good prompt describes each frame and most objects in the video. '
                         'In particular, it should contain the object or attribute that we want to edit or preserve.',
                    max_lines=1,
                    placeholder='Example: "a silver jeep driving down a curvy road in the countryside"',
                    value='a silver jeep driving down a curvy road in the countryside')
                target_prompt = gr.Textbox(
                    label='Target Prompt',
                    info='A reasonable pairing of video and prompt achieves better results '
                         '(e.g., the "sunflower" video with a "Van Gogh" prompt works better than with a "Monet" prompt).',
                    max_lines=1,
                    placeholder='Example: "watercolor painting of a silver jeep driving down a curvy road in the countryside"',
                    value='watercolor painting of a silver jeep driving down a curvy road in the countryside')

            with gr.Accordion('DDIM Parameters', open=True):
                num_steps = gr.Slider(
                    label='Number of Steps',
                    info='A larger value gives better editing capacity but takes more time and memory.',
                    minimum=0,
                    maximum=50,
                    step=1,
                    value=10)
                guidance_scale = gr.Slider(
                    label='CFG Scale',
                    minimum=0,
                    maximum=50,
                    step=0.1,
                    value=7.5)

            run_button = gr.Button('Generate')

            # gr.Markdown('''
            # - It takes a few minutes to download the model the first time.
            # - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G (10 seconds with A100).
            # ''')
            gr.Markdown('''
            todo
            ''')

        with gr.Column():
            result = gr.Video(label='Result')
            result.style(height=512, width=512)

            with gr.Accordion('FateZero Parameters for attention fusing', open=True):
                cross_replace_steps = gr.Slider(
                    label='cross-attention replace steps',
                    info='More steps replace more cross-attention to preserve the semantic layout.',
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7)
                self_replace_steps = gr.Slider(
                    label='self-attention replace steps',
                    info='More steps replace more spatial-temporal self-attention to preserve geometry and motion.',
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7)
                enhance_words = gr.Textbox(
                    label='words to be enhanced',
                    info='Amplify the cross-attention of the target words.',
                    max_lines=1,
                    placeholder='Example: "watercolor"',
                    value='watercolor')
                enhance_words_value = gr.Slider(
                    label='Amplify the target cross-attention',
                    info='A larger value brings in more elements of the target words.',
                    minimum=0.0,
                    maximum=20.0,
                    step=1,
                    value=10)

    with gr.Row():
        examples = [
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/teaser_car-turn',
                'a silver jeep driving down a curvy road in the countryside',
                'watercolor painting of a silver jeep driving down a curvy road in the countryside',
                0.8,
                0.8,
                'watercolor',
                10,
                10,
                7.5,
            ],
            [
                'CompVis/stable-diffusion-v1-4',
                'FateZero/data/style/sunflower',
                'a yellow sunflower',
                'van gogh style painting of a yellow sunflower',
                0.5,
                0.5,
                'van gogh',
                10,
                10,
                7.5,
            ],
        ]
        gr.Examples(
            examples=examples,
            inputs=[
                model_id,
                data_path,
                source_prompt,
                target_prompt,
                cross_replace_steps,
                self_replace_steps,
                enhance_words,
                enhance_words_value,
                num_steps,
                guidance_scale,
            ],
            outputs=result,
            fn=merge_config_then_run,
            cache_examples=os.getenv('SYSTEM') == 'spaces')

    # model_id.change(fn=app.load_model_info,
    #                 inputs=model_id,
    #                 outputs=[
    #                     base_model_used_for_training,
    #                     prompt_used_for_training,
    #                 ])

    # Positional inputs for merge_config_then_run; the order must match its signature.
    inputs = [
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
        user_input_video,
        *ImageSequenceDataset_list,
    ]
    # prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
    # run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)

demo.queue().launch()