import gradio as gr

from vid2persona import init
from vid2persona.pipeline import vlm
from vid2persona.pipeline import llm

# Model loaded at start-up and pre-selected in the (hidden) model dropdown.
DEFAULT_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"

init.init_model(DEFAULT_MODEL_ID)
init.auth_gcp()
init.get_env_vars()

# Directory of prompt templates, passed to both the VLM and the LLM pipeline.
prompt_tpl_path = "vid2persona/prompts"


def _buttons(interactive: bool, count: int = 2) -> list:
    """Return ``count`` button updates that all set ``interactive``."""
    return [gr.Button(interactive=interactive) for _ in range(count)]


async def extract_traits(video_path):
    """Extract a character's traits from a video and reset the chat UI.

    Args:
        video_path: Value of the ``gr.Video`` input (the uploaded clip).

    Returns:
        Values for ``[traits, chatbot, user_input, clear, regen, stop]``: the
        extracted traits, an empty chat history, an enabled empty textbox and
        three enabled buttons.
    """
    traits = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path,
    )
    # The VLM result may wrap characters in a "characters" list; keep only the
    # first one. An empty list is left as-is instead of raising IndexError.
    if "characters" in traits and traits["characters"]:
        traits = traits["characters"][0]

    return [
        traits,
        [],
        gr.Textbox("", interactive=True),
        *_buttons(True, 3),
    ]


async def _stream_reply(
    message: str,
    messages: list,
    traits: dict,
    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Stream the persona's reply to ``message`` into ``messages[-1][1]``.

    ``messages[-1]`` must already be the pair ``[message, ""]``; each streamed
    chunk is appended to its second element in place.

    Yields:
        Values for ``[chatbot, user_input, clear, regen]``: one update per
        chunk with both buttons disabled, then a final update re-enabling them.
    """
    async for partial_response in llm.chat(
        message,
        messages,
        traits,
        prompt_tpl_path,
        model_id,
        max_input_token_length,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        # NOTE(review): the token is deliberately not forwarded (the original
        # had `init.hf_access_token` commented out); confirm before enabling.
        hf_token=None,
    ):
        messages[-1][1] += partial_response
        yield [messages, "", *_buttons(False)]

    yield [messages, "", *_buttons(True)]


async def conversation(
    message: str,
    messages: list,
    traits: dict,
    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Append the user's ``message`` to the chat and stream the persona's reply.

    Args:
        message: Text typed by the user.
        messages: Chat history as ``[user, bot]`` pairs (the chatbot value).
        traits: Persona traits produced by ``extract_traits``.
        model_id, max_input_token_length, max_new_tokens, temperature, top_p,
        top_k, repetition_penalty: Generation settings forwarded to ``llm.chat``.

    Yields:
        Values for ``[chatbot, user_input, clear, regen]``.
    """
    if not (message and message.strip()):
        # Nothing to send: leave the history and the buttons untouched.
        yield [messages, "", gr.update(), gr.update()]
        return

    messages = messages + [[message, ""]]
    # Keep the typed text visible until the first chunk arrives.
    yield [messages, message, *_buttons(False)]

    async for update in _stream_reply(
        message,
        messages,
        traits,
        model_id,
        max_input_token_length,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ):
        yield update


async def regen_conversation(
    messages: list,
    traits: dict,
    model_id: str,
    max_input_token_length: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Discard the last bot reply and generate a new one for the same message.

    Args:
        messages: Chat history as ``[user, bot]`` pairs (the chatbot value).
        traits: Persona traits produced by ``extract_traits``.
        model_id, max_input_token_length, max_new_tokens, temperature, top_p,
        top_k, repetition_penalty: Generation settings forwarded to ``llm.chat``.

    Yields:
        Values for ``[chatbot, user_input, clear, regen]``.
    """
    if not messages:
        # "regenerate" is enabled right after trait extraction, when the chat
        # is still empty; there is nothing to regenerate, so change nothing.
        yield [messages, "", gr.update(), gr.update()]
        return

    message = messages[-1][0]
    messages = messages[:-1] + [[message, ""]]
    yield [messages, "", *_buttons(False)]

    async for update in _stream_reply(
        message,
        messages,
        traits,
        model_id,
        max_input_token_length,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ):
        yield update


with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
    gr.Markdown(
        "This project breathes life into video characters by using AI to "
        "describe their personality and then chat with you as them."
    )

    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip")
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate traits")

    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot(
            [],
            label="chatbot",
            elem_id="chatbot",
            elem_classes=["chatbot-no-label"],
        )
        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)
        user_input = gr.Textbox(
            placeholder="ask anything",
            interactive=False,
            elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"],
        )

    with gr.Accordion("parameters' control pane", open=False):
        model_id = gr.Dropdown(
            choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS,
            value=DEFAULT_MODEL_ID,
            label="Model ID",
            visible=False,
        )

        with gr.Row():
            max_input_token_length = gr.Slider(
                minimum=1024, maximum=4096, value=4096, label="max-input-tokens"
            )
            max_new_tokens = gr.Slider(
                minimum=128, maximum=2048, value=256, label="max-new-tokens"
            )

        with gr.Row():
            temperature = gr.Slider(
                minimum=0, maximum=2, step=0.1, value=0.6, label="temperature"
            )
            # top-p is a cumulative probability mass, so it must lie in (0, 1].
            top_p = gr.Slider(
                minimum=0.05, maximum=1, step=0.05, value=0.9, label="top-p"
            )
            # top-k is an integer token count; the previous 0-2 range with a
            # 0.1 step could not even hold the default of 50.
            top_k = gr.Slider(
                minimum=1, maximum=100, step=1, value=50, label="top-k"
            )
            repetition_penalty = gr.Slider(
                minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty"
            )

    with gr.Row():
        gr.Markdown(
            "[![GitHub Repo](https://img.shields.io/badge/GitHub%20Repo-gray?style=for-the-badge&logo=github&link=https://github.com/deep-diver/Vid2Persona)](https://github.com/deep-diver/Vid2Persona) "
            "[![Chansung](https://img.shields.io/badge/Chansung-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/algo_diver)](https://twitter.com/algo_diver) "
            "[![Sayak](https://img.shields.io/badge/Sayak-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/RisingSayak)](https://twitter.com/RisingSayak )",
            elem_id="bottom-md",
        )

    generation_inputs = [
        traits,
        model_id,
        max_input_token_length,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ]
    chat_outputs = [chatbot, user_input, clear, regen]

    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop],
    )

    conv = user_input.submit(
        conversation,
        [user_input, chatbot, *generation_inputs],
        chat_outputs,
    )

    clear.click(
        lambda: [gr.Chatbot([]), *_buttons(False)],
        None,
        [chatbot, clear, regen],
    )

    conv_regen = regen.click(
        regen_conversation,
        [chatbot, *generation_inputs],
        chat_outputs,
    )

    # Cancel any in-flight generation and re-enable the chat controls.
    stop.click(
        lambda: _buttons(True, 3),
        None,
        [clear, regen, stop],
        cancels=[conv, conv_regen],
    )

    # Only the first sample is enabled; assets/sample2-4.mp4 were left out of
    # the example list in the original code.
    gr.Examples(
        examples=[["assets/sample1.mp4"]],
        inputs=video,
        outputs=[traits, chatbot, user_input, clear, regen, stop],
        fn=extract_traits,
        cache_examples=True,
    )

demo.launch()