# Rally_ChatBot / app.py
# (Hugging Face Space; commit e9fc801 "Update app.py" by Mikunono, 4.95 kB —
# the raw/history/blame lines were file-viewer chrome, not code.)
import gradio as gr
import librosa
import numpy as np
from transformers import pipeline
######################## ASR model ###############################
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Load the Whisper processor/model once at import time.
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
# Stored under a distinct name: later in this file the LLaMA model was also
# bound to ``model``, which clobbered Whisper and broke ASR_model.
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
whisper_model.config.forced_decoder_ids = None
sample_rate = 16000

def ASR_model(audio, sr=16000):
    """Transcribe a mono waveform with Whisper.

    Args:
        audio: 1-D float waveform (e.g. first element of ``librosa.load``).
        sr: sampling rate of ``audio``; Whisper expects 16 kHz.

    Returns:
        list[str]: one transcription per batch item.
    """
    input_features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features
    # Generate token ids, then decode them to text without special tokens.
    # (The original decoded twice — once with skip_special_tokens=False and
    # immediately discarded the result; the dead call is removed.)
    predicted_ids = whisper_model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)
######################## LLaMA model ###############################
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_name_or_path = "TheBloke/llama2_7b_chat_uncensored-GPTQ"
# To use a different branch, change revision (e.g. revision="main").
# GPTQConfig replaces the undefined ``QuantizationConfig`` (NameError in the
# original); exllama kernels are disabled so the model loads on hardware
# where they are unavailable.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=True,
    revision="main",
    quantization_config=GPTQConfig(bits=4, disable_exllama=True),
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
# Shared text-generation pipeline used by the chat handlers below.
Llama_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=20,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1,
)
# Seed conversation shown to the model; "Rally" is the bot persona.
history = "\n".join([
    "User: Hello, Rally?",
    "Rally: I'm happy to see you again. What you want to talk to day?",
    "User: Let's talk about food",
    "Rally: Sure.",
    "User: I'm hungry right now. Do you know any Vietnamese food?",
])

# ChatML-style system instruction wrapped around the running conversation.
prompt_template = (
    "<|im_start|>system\n"
    "Talk one sentence to continue the conversation<|im_end|>\n"
    + history
    + "\nRally:"
)

# Smoke-test the pipeline once at start-up.
print(Llama_pipe(prompt_template)[0]['generated_text'])
def RallyRespone(chat_history, message):
    """Generate Rally's next line for *message*, given prior *chat_history*.

    Fixes two defects in the original: ``chat_history += "User: ..."``
    char-extended the caller's gradio history list in place (list ``+=`` str),
    and the generation call used the static module-level ``prompt_template``,
    so both the history and the new message were ignored.

    Args:
        chat_history: prior turns — a string transcript, or a list of
            (user, bot) pairs as produced by ``gr.Chatbot``.
        message: the new user utterance.

    Returns:
        str: the model's reply text (without the "Rally:" prefix).
    """
    # Normalise the history into a plain transcript without mutating it.
    if isinstance(chat_history, str):
        transcript = chat_history
    else:
        turns = []
        for user_turn, bot_turn in (chat_history or []):
            if user_turn:
                turns.append("User: " + user_turn)
            if bot_turn:
                turns.append("Rally: " + bot_turn)
        transcript = "\n".join(turns)
    prompt = (
        "<|im_start|>system\n"
        "Talk one sentence to continue the conversation<|im_end|>\n"
        + (transcript + "\n" if transcript else "")
        + "User: " + message + "\nRally:"
    )
    generated = Llama_pipe(prompt)[0]['generated_text']
    # Keep only what follows the last "Rally:" marker (always present,
    # since the prompt itself ends with it and is echoed in the output).
    return generated[generated.rfind("Rally:") + len("Rally:"):].strip()
########################Gradio UI###############################
# Chatbot demo: transcribe microphone or uploaded audio with Whisper, show the
# transcription on the option buttons, and chat with the GPTQ LLaMA-2 model.
def add_file(files):
    """Return the filesystem path (.name) of an uploaded file object."""
    return files.name
def print_like_dislike(x: gr.LikeData):
    """Log the index, content, and like/dislike flag of a rated message."""
    print(f"{x.index} {x.value} {x.liked}")
def upfile(files):
    """Load an uploaded audio file at 16 kHz and transcribe it.

    Returns the transcription twice — one copy per option button.
    """
    waveform, _sr = librosa.load(files, sr=16000)
    print(waveform)
    transcription = ASR_model(waveform)
    first = transcription[0]
    return [first, first]
def transcribe(audio):
    """Transcribe a ``(sample_rate, waveform)`` tuple from ``gr.Audio``.

    Returns the transcription twice — one copy per option button.

    NOTE(review): ``transcriber`` is not defined anywhere in this file, so
    this handler raises NameError at runtime; it presumably should call the
    Whisper ``ASR_model`` helper instead — confirm with the author.
    """
    sr, y = audio
    y = y.astype(np.float32)
    # Peak-normalise, guarding against 0/0 on silent input.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak
    # Run the model once instead of twice (the original duplicated the call).
    text = transcriber({"sampling_rate": sr, "raw": y})["text"]
    return text, text
# def recommand(text):
# ret = "answer for"
# return ret + text
def add_text(history, text):
    """Append the user's message (no bot reply yet) and clear the input box."""
    updated = history + [(text, None)]
    cleared_box = gr.Textbox(value="", interactive=False)
    return updated, cleared_box
# def bot(history):
# response = "**That's cool!**"
# history[-1][1] = ""
# for character in response:
# history[-1][1] += character
# time.sleep(0.05)
# yield history
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
    )
    file_output = gr.File()

    def respond(message, chat_history):
        """Send the chosen option text to the bot and append the exchange.

        Returns the bot reply (shown on the clicked button) and the
        updated chat history.
        """
        bot_message = RallyRespone(chat_history, message)
        chat_history.append((message, bot_message))
        # NOTE: the original called time.sleep(2) here, but ``time`` was
        # never imported (NameError on every click); the artificial delay
        # served no purpose and is removed.
        print(chat_history[-1])
        return chat_history[-1][-1], chat_history

    with gr.Row():
        with gr.Column():
            audio_speech = gr.Audio(sources=["microphone"])
            submit = gr.Button("Submit")
            send = gr.Button("Send")
            btn = gr.UploadButton("📁", file_types=["audio"])
        with gr.Column():
            # The two candidate-transcription buttons; clicking one sends
            # its text as the user message.
            opt1 = gr.Button("1: ")
            opt2 = gr.Button("2: ")

    # Upload -> file widget -> transcription onto the option buttons.
    file_msg = btn.upload(add_file, btn, file_output)
    submit.click(upfile, inputs=file_output, outputs=[opt1, opt2])
    send.click(transcribe, inputs=audio_speech, outputs=[opt1, opt2])
    opt1.click(respond, [opt1, chatbot], [opt1, chatbot])
    opt2.click(respond, [opt2, chatbot], [opt2, chatbot])
    chatbot.like(print_like_dislike, None, None)
if __name__ == "__main__":
    # Enable request queuing (required for concurrent users / generators),
    # then start the Gradio server with verbose error output.
    demo.queue()
    demo.launch(debug=True)