# Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import gradio as gr

from model import SALMONN


class ff:
    """Offline stand-in for ``SALMONN`` that never loads a model.

    Enable it via the commented-out ``model = ff()`` line to click through the
    Gradio UI without checkpoints or a GPU.
    """

    def generate(self, wav_path, prompt, prompt_pattern=None, num_beams=None,
                 temperature=None, top_p=None):
        """Log the request and return a canned reply.

        ``gradio_answer`` calls ``model.generate`` with ``wav_path``,
        ``prompt``, ``num_beams``, ``temperature`` and ``top_p`` as keywords
        and never passes ``prompt_pattern``, so everything after ``prompt``
        is optional here (previously the missing ``prompt_pattern`` raised
        ``TypeError``).

        Returns:
            A one-element list holding the reply text. ``gradio_answer`` reads
            the answer as ``llm_message[0]``; returning a bare string would
            make it show only the first character.
        """
        print(f'wav_path: {wav_path}, prompt: {prompt}, temperature: {temperature}, num_beams: {num_beams}, top_p: {top_p}')
        return ["I'm sorry, but I cannot answer that question as it is not clear what you are asking. Can you please provide more context or clarify your question?"]
parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--ckpt_path", type=str, default=None) parser.add_argument("--whisper_path", type=str, default=None) parser.add_argument("--beats_path", type=str, default=None) parser.add_argument("--vicuna_path", type=str, default=None) parser.add_argument("--low_resource", action='store_true', default=False) parser.add_argument("--lora_alpha", type=int, default=32) parser.add_argument("--port", default=9527) args = parser.parse_args() # model = ff() model = SALMONN( ckpt=args.ckpt_path, whisper_path=args.whisper_path, beats_path=args.beats_path, vicuna_path=args.vicuna_path, lora_alpha=args.lora_alpha, low_resource=args.low_resource ) model.to(args.device) model.eval() # gradio def gradio_reset(chat_state): chat_state = [] return (None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your wav first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state) def upload_speech(gr_speech, text_input, chat_state): if gr_speech is None: return None, None, gr.update(interactive=True), chat_state, None chat_state.append(gr_speech) return (gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state) def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state chat_state.append(user_message) chatbot.append([user_message, None]) # return gr.update(interactive=False, placeholder='Currently only single round conversations are supported.'), chatbot, chat_state def gradio_answer(chatbot, chat_state, num_beams, temperature, top_p): llm_message = model.generate( wav_path=chat_state[0], prompt=chat_state[1], num_beams=num_beams, temperature=temperature, top_p=top_p, ) chatbot[-1][1] = llm_message[0] return 
chatbot, chat_state title = """