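# poor-mans-duplex / duplex.py
# Source: Hugging Face Space file view (author avatar: versae).
# Last commit: "Update duplex.py" (11abfd0).
# Page links: raw, history, blame. Scan status: no virus. File size: 7.4 kB.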
import os
import json
import random
import string
import gradio as gr
import requests
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import sys
import gradio as gr
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Runtime configuration, read once from the environment at import time.

# DEBUG is on when the variable starts with "t", "y" or "1" (case-insensitive),
# e.g. "true", "Yes", "1". Unset, empty or any other value turns it off; the
# prefix test (instead of indexing [0]) keeps an empty value from raising.
DEBUG = os.environ.get("DEBUG", "false").strip().lower().startswith(("t", "y", "1"))
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
# Upper bound, in whitespace-separated words, used when deciding how much
# conversation history fits in the prompt.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

# Markdown rendered at the top of the page.
HEADER = """
# Poor Man's Duplex
""".strip()

# HTML rendered at the bottom of the page (visitor badge).
FOOTER = """
<div align=center>
<img src="https://visitor-badge.glitch.me/badge?page_id=versae/poor-mans-duplex"/>
</div>
""".strip()
# --- Spanish speech stack ---
# Speech-to-text: a CTC acoustic model wrapped in an ASR pipeline; the
# tokenizer, feature extractor and decoder all come from the same processor.
asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
model_instance_es = AutoModelForCTC.from_pretrained(asr_model_name_es)
processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_es)
asr_es = pipeline(
    "automatic-speech-recognition",
    model=model_instance_es,
    tokenizer=processor_es.tokenizer,
    feature_extractor=processor_es.feature_extractor,
    decoder=processor_es.decoder,
)

# Text-to-speech: loaded as a callable Gradio interface from the Hub.
tts_model_name = "facebook/tts_transformer-es-css10"
speak_es = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_es(input_file):
    """Return the Spanish transcription of the audio in *input_file*.

    The audio is processed in 5-second chunks with a 1-second stride.
    """
    recognized = asr_es(input_file, chunk_length_s=5, stride_length_s=1)
    return recognized["text"]
def generate_es(text, **kwargs):
    """Generate a Spanish continuation of *text* with the hosted BERTIN GPT-J-6B Space.

    Args:
        text: Prompt sent to the remote model.
        **kwargs: Accepted so the function can be called like ``generate_en``;
            the values are not forwarded, the remote call uses fixed settings.

    Returns:
        The first element of the ``"data"`` list in the JSON reply, or ``""``
        when the request fails, times out, or the reply has an unexpected shape.
    """
    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
    # Positional inputs, in order:
    # text, max_length=100, top_k=100, top_p=50, temperature=0.95, do_sample=True, do_clean=True
    payload = {"data": [text, 100, 100, 50, 0.95, True, True]}
    # Without a timeout a stalled Space would block the request indefinitely.
    timeout_seconds = 120
    try:
        response = requests.post(api_uri, data=json.dumps(payload), timeout=timeout_seconds)
    except requests.RequestException:
        return ""
    if not response.ok:
        return ""
    try:
        body = response.json()  # parsed once, reused for logging and the result
    except ValueError:
        return ""
    print(body)
    try:
        return body["data"][0]
    except (KeyError, IndexError, TypeError):
        return ""
# --- English speech stack ---
# Speech-to-text: a CTC acoustic model wrapped in an ASR pipeline; the
# tokenizer, feature extractor and decoder all come from the same processor.
asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en)
processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en)
asr_en = pipeline(
    "automatic-speech-recognition",
    model=model_instance_en,
    tokenizer=processor_en.tokenizer,
    feature_extractor=processor_en.feature_extractor,
    decoder=processor_en.decoder,
)

# Text-to-speech: loaded as a callable Gradio interface from the Hub.
tts_model_name = "facebook/fastspeech2-en-ljspeech"
speak_en = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_en(input_file):
    """Return the English transcription of the audio in *input_file*.

    The audio is processed in 5-second chunks with a 1-second stride.
    """
    recognized = asr_en(input_file, chunk_length_s=5, stride_length_s=1)
    return recognized["text"]


# English text generation: hosted GPT-J-6B, loaded as a callable interface.
generate_iface = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")
def generate_en(text, **kwargs):
    """Generate an English continuation of *text* through ``generate_iface``.

    Extra keyword arguments are accepted but not forwarded. The raw result is
    printed; a falsy result is returned as ``""``.
    """
    completion = generate_iface(text)
    print(completion)
    if completion:
        return completion
    return ""
def select_lang(lang):
    """Return the ``(generate, transcribe, speak)`` callables for *lang*.

    ``"spanish"`` (compared case-insensitively) selects the Spanish stack;
    every other value selects the English one.
    """
    wants_spanish = lang.lower() == "spanish"
    if not wants_spanish:
        return generate_en, transcribe_en, speak_en
    return generate_es, transcribe_es, speak_es
def select_lang_vars(lang):
    """Return the default ``(agent, user, context)`` strings for *lang*.

    ``"spanish"`` (compared case-insensitively) yields the Spanish defaults;
    every other value yields the English ones. The context is a template that
    still contains the ``{AGENT}`` and ``{USER}`` placeholders.
    """
    if lang.lower() != "spanish":
        english_context = (
            "The next conversation is an excerpt from an interview to {AGENT} that appeared in the New York Times:\n"
            "{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.\n"
            "{AGENT}: Thanks. The pleasure is mine."
        )
        return "ELEUTHER", "INTERVIEWER", english_context

    spanish_context = (
        "La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:\n"
        "{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.\n"
        "{AGENT}: Gracias. El placer es mío."
    )
    return "BERTIN", "ENTREVISTADOR", spanish_context
def chat_with_gpt(lang, agent, user, context, audio_in, history):
    """Run one voice turn: transcribe the user, generate a reply and speak it.

    Args:
        lang: Language name. ``"spanish"`` (any case) selects the Spanish
            stack; anything else, including ``None``, selects English.
        agent: Speaker label for the agent; the language default is used
            when it is empty.
        user: Speaker label for the user; the language default is used when
            it is empty.
        context: Prompt preamble. ``{USER}`` and ``{AGENT}`` placeholders are
            replaced with the speaker labels.
        audio_in: Recording handed to the transcriber. When it is empty the
            turn is skipped.
        history: List of ``(user_message, response)`` pairs from previous
            turns, or ``None``.

    Returns:
        ``(history, history, speech)``: the updated history twice (one per
        bound output) and whatever the text-to-speech callable returns, or
        ``None`` for the speech when there was no recording.
    """
    history = list(history or [])
    if not audio_in:
        # Nothing was recorded (e.g. the input was cleared): keep the
        # conversation as it is instead of failing inside the transcriber.
        return history, history, None

    lang = lang or "English"
    generate, transcribe, speak = select_lang(lang)
    default_agent, default_user, _ = select_lang_vars(lang)
    # Resolve empty labels once so every later use (history, prompt, response
    # clean-up) sees the same non-empty names; splitting on "" would raise.
    agent = agent or default_agent
    user = user or default_user

    user_message = transcribe(audio_in)
    generation_kwargs = {"max_length": 25}
    message = _capitalize_first_word(user_message)

    context = _render_context(context, user, agent)
    word_budget = MAX_LENGTH - (generation_kwargs["max_length"] + len(context.split()))
    history_context = _fit_history(history, user, agent, word_budget)
    if history_context:
        # Keep the first history line from being glued to the preamble.
        context += "\n" + history_context

    prompt = f"{context}\n\n{user}: {message}.\n"
    response = ""
    for _ in range(5):  # retry a few times when the model returns nothing new
        raw = generate(prompt, **generation_kwargs)
        if DEBUG:
            print("\n-----" + raw + "-----\n")
        response = _clean_response(raw, agent, user, message)
        if _is_new_reply(response, message):
            break

    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)

    if not response.strip():
        # Compare against the lower-cased name; "Spanish" could never match.
        if lang.lower() == "spanish":
            response = "Lo siento, no puedo hablar ahora"
        else:
            response = "Sorry, can't talk right now"

    history.append((user_message, response))
    return history, history, speak(response)


def _capitalize_first_word(text):
    """Capitalize the first word of *text* and leave the rest untouched."""
    first, separator, rest = text.partition(" ")
    return first.capitalize() + separator + rest


def _render_context(template, user, agent):
    """Fill the speaker placeholders in *template* and end it with '.' or ':'."""
    template = template or ""
    try:
        rendered = template.format(USER=user, AGENT=agent)
    except (KeyError, IndexError, ValueError):
        # The template is user-editable; unrelated braces must not abort the turn.
        rendered = template.replace("{USER}", user).replace("{AGENT}", agent)
    rendered = rendered.strip()
    if rendered and rendered[-1] not in ".:":
        rendered += "."
    return rendered


def _fit_history(history, user, agent, word_budget):
    """Render the most recent turns of *history* that fit in *word_budget* words."""
    turns = [
        f"{user}: {past_message.capitalize()}.\n{agent}: {past_response}."
        for past_message, past_response in history
    ]
    word_counts = [len(turn.split()) for turn in turns]
    total_words = sum(word_counts)
    first_kept = 0
    # Drop the oldest turns first until the remainder fits.
    while first_kept < len(turns) and total_words > word_budget:
        total_words -= word_counts[first_kept]
        first_kept += 1
    return "\n".join(turns[first_kept:])


def _clean_response(raw, agent, user, message):
    """Reduce the generated text *raw* to the agent's reply on its last line."""
    response = raw.split("\n")[-1]
    for speaker in (agent, user):
        # Keep only what follows the last speaker label, if anything does.
        if speaker and speaker in response:
            tail = response.split(speaker)[-1]
            if tail:
                response = tail
    if response and response[0] in string.punctuation:
        response = response[1:].strip()
    echoed_turn = f"{user}: {message}"
    if response.strip().startswith(echoed_turn):
        response = response.strip().split(echoed_turn)[-1]
    return response


def _is_new_reply(response, message):
    """Return True when *response* has content and is not an echo of *message*."""
    reply_core = response.replace(".", "").strip()
    return bool(reply_core) and reply_core != message.replace(".", "").strip()
# Gradio UI: a language selector, an editable prompt context, a microphone
# input and an audio output. Recording audio triggers one conversation turn.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    # `value` is the initial-selection keyword used by every other component
    # here; the legacy `default` keyword would leave the radio unselected.
    lang = gr.Radio(label="Language", choices=["English", "Spanish"], value="English", type="value")
    AGENT, USER, CONTEXT = select_lang_vars("English")
    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)
    with gr.Row():
        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
        audio_out = gr.Audio(label="Agent", interactive=False)
    # chat_btn = gr.Button("Submit")
    with gr.Row():
        user = gr.Textbox(label="User", value=USER)
        agent = gr.Textbox(label="Agent", value=AGENT)
    # Switching language resets the speaker labels and the context template.
    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])
    # Conversation history kept between calls as a list of (user, agent) pairs.
    history = gr.Variable(value=[])
    chatbot = gr.Variable()  # gr.Chatbot(color_map=("green", "gray"), visible=False)
    # chat_btn.click(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out])
    audio_in.change(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out])
    gr.Markdown(FOOTER)
demo.launch()