File size: 4,936 Bytes
fff06c1 d0ad885 fff06c1 d0ad885 fff06c1 d0ad885 fff06c1 d0ad885 fff06c1 d0ad885 fff06c1 d0ad885 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
from huggingface_hub import HfApi
from constants import MODEL_LIBRARY_ORG_NAME
from inference import InferencePipeline
class InferenceUtil:
def __init__(self, hf_token: str | None):
self.hf_token = hf_token
def load_hub_model_list(self) -> dict:
api = HfApi(token=self.hf_token)
choices = [
info.modelId
for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
]
return gr.update(choices=choices,
value=choices[0] if choices else None)
def load_model_info(self, model_id: str) -> tuple[str, str]:
try:
card = InferencePipeline.get_model_card(model_id, self.hf_token)
except Exception:
return '', ''
base_model = getattr(card.data, 'base_model', '')
training_prompt = getattr(card.data, 'training_prompt', '')
return base_model, training_prompt
def reload_model_list_and_update_model_info(self) -> tuple[dict, str, str]:
model_list_update = self.load_hub_model_list()
model_list = model_list_update['choices']
model_info = self.load_model_info(model_list[0] if model_list else '')
return model_list_update, *model_info
# Markdown header rendered at the top of the page.
TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'

# Hub token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv('HF_TOKEN')
# Pipeline whose ``run`` produces the result video.
pipe = InferencePipeline(HF_TOKEN)
# Hub helpers that populate the model dropdown and model-info fields.
app = InferenceUtil(HF_TOKEN)

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        # Left column: model selection and generation parameters.
        with gr.Column():
            with gr.Box():
                reload_button = gr.Button('Reload Model List')
                # Starts empty; filled when the reload button is clicked.
                model_id = gr.Dropdown(label='Model ID',
                                       choices=None,
                                       value=None)
                with gr.Accordion(
                        label='Model info (Base model and prompt used for training)',
                        open=False):
                    with gr.Row():
                        base_model_used_for_training = gr.Text(
                            label='Base model', interactive=False)
                        prompt_used_for_training = gr.Text(
                            label='Training prompt', interactive=False)
            prompt = gr.Textbox(label='Prompt',
                                max_lines=1,
                                placeholder='Example: "A panda is surfing"')
            video_length = gr.Slider(label='Video length',
                                     minimum=4,
                                     maximum=12,
                                     step=1,
                                     value=8)
            fps = gr.Slider(label='FPS',
                            minimum=1,
                            maximum=12,
                            step=1,
                            value=1)
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=100000,
                             step=1,
                             value=0)
            with gr.Accordion('Other Parameters', open=False):
                # Zero denoising steps is not a meaningful sampling run
                # (a diffusion scheduler presumably rejects it — confirm in
                # InferencePipeline), so require at least one step.
                num_steps = gr.Slider(label='Number of Steps',
                                      minimum=1,
                                      maximum=100,
                                      step=1,
                                      value=50)
                guidance_scale = gr.Slider(label='CFG Scale',
                                           minimum=0,
                                           maximum=50,
                                           step=0.1,
                                           value=7.5)
            run_button = gr.Button('Generate')
            gr.Markdown('''
- It takes a few minutes to download model first.
- Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
''')
        # Right column: generated output.
        with gr.Column():
            result = gr.Video(label='Result')

    # Refresh the dropdown and, in the same call, the info for its new value.
    reload_button.click(fn=app.reload_model_list_and_update_model_info,
                        inputs=None,
                        outputs=[
                            model_id,
                            base_model_used_for_training,
                            prompt_used_for_training,
                        ])
    model_id.change(fn=app.load_model_info,
                    inputs=model_id,
                    outputs=[
                        base_model_used_for_training,
                        prompt_used_for_training,
                    ])

    # Passed positionally to pipe.run, so this order must match its parameters.
    inputs = [
        model_id,
        prompt,
        video_length,
        fps,
        seed,
        num_steps,
        guidance_scale,
    ]
    # Pressing Enter in the prompt box is equivalent to clicking Generate.
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

demo.queue().launch()
|