"""Poor Man's Duplex: a walkie-talkie style voice chat with GPT-J.

The user records an utterance, which is transcribed (wav2vec2), re-cased,
appended to an interview-style prompt, completed by a remote GPT-J endpoint
and finally synthesised back to audio. English and Spanish are supported.
"""
import html
import json
import logging
import os
import random
import string
import sys

import gradio as gr
import numpy as np
import requests
import soundfile as sf
from transformers import (
    AutoModelForCausalLM,
    AutoModelForCTC,
    AutoTokenizer,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    pipeline,
    set_seed,
)

# Any value starting with t/y/1 (case-insensitive) enables debug output.
# A tuple is used instead of the string "ty1" because `"" in "ty1"` is True,
# which would switch debugging on for an empty DEBUG variable.
DEBUG = os.environ.get("DEBUG", "false").strip().lower()[:1] in ("t", "y", "1")
# Word budget shared by the context, the kept history and the generation.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
DEFAULT_LANG = os.environ.get("DEFAULT_LANG", "English")
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
# Seconds to wait for the remote generation endpoints before giving up.
REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", 60))
# Generation length used when a caller does not pass `max_length`.
DEFAULT_RESPONSE_LENGTH = 50

HEADER = """
# Poor Man's Duplex

Talk to a language model like you talk on a Walkie-Talkie! Well, with larger latencies.
The models are [EleutherAI's GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) for English, and [BERTIN GPT-J-6B](https://huggingface.co/bertin-project/bertin-gpt-j-6B) for Spanish.
""".strip()

FOOTER = """
""".strip()


def _post_json(api_uri, payload):
    """POST `payload` as a JSON body and return the decoded JSON reply.

    Returns None on any network error, non-2xx status or undecodable body so
    that callers can degrade to an empty generation instead of crashing.
    """
    try:
        response = requests.post(
            api_uri, data=json.dumps(payload), timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException:
        return None
    if not response.ok:
        return None
    try:
        return response.json()
    except ValueError:
        return None


# --- Spanish models ---------------------------------------------------------
asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
model_instance_es = AutoModelForCTC.from_pretrained(
    asr_model_name_es, use_auth_token=HF_AUTH_TOKEN
)
processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(
    asr_model_name_es, use_auth_token=HF_AUTH_TOKEN
)
asr_es = pipeline(
    "automatic-speech-recognition",
    model=model_instance_es,
    tokenizer=processor_es.tokenizer,
    feature_extractor=processor_es.feature_extractor,
    decoder=processor_es.decoder,
)
tts_model_name = "facebook/tts_transformer-es-css10"
speak_es = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN)


def transcribe_es(input_file):
    """Transcribe a Spanish audio file to text using chunked inference."""
    return asr_es(input_file, chunk_length_s=5, stride_length_s=1)["text"]


def generate_es(text, **kwargs):
    """Complete `text` with the BERTIN GPT-J-6B Space.

    Only `max_length` is read from `kwargs`; the remaining sampling values
    are fixed. Returns the generated text, or "" if the request fails.
    """
    # Positional payload: text="Prompt", max_length=100, top_k=100, top_p=50,
    # temperature=0.95, do_sample=True, do_clean=True
    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
    max_length = kwargs.get("max_length", DEFAULT_RESPONSE_LENGTH)
    body = _post_json(
        api_uri, {"data": [text, max_length, 100, 50, 0.95, True, True]}
    )
    if body is None:
        return ""
    if DEBUG:
        print("Spanish response >", body)
    try:
        return body["data"][0]
    except (KeyError, IndexError, TypeError):
        return ""


# --- English models ---------------------------------------------------------
asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en)
processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en)
asr_en = pipeline(
    "automatic-speech-recognition",
    model=model_instance_en,
    tokenizer=processor_en.tokenizer,
    feature_extractor=processor_en.feature_extractor,
    decoder=processor_en.decoder,
)
tts_model_name = "facebook/fastspeech2-en-ljspeech"
speak_en = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN)


def transcribe_en(input_file):
    """Transcribe an English audio file to text using chunked inference."""
    return asr_en(input_file, chunk_length_s=5, stride_length_s=1)["text"]


# Silent placeholder clip returned when there is nothing to say.
empty_audio = 'empty.flac'
sf.write(empty_audio, [], 16000)
# Restores casing/punctuation on the raw ASR transcript.
deuncase = gr.Interface.load("huggingface/pere/DeUnCaser", api_key=HF_AUTH_TOKEN)


def generate_en(text, **kwargs):
    """Complete `text` with the EleutherAI GPT-J completion API.

    Only `max_length` is read from `kwargs`. Returns the generated text with
    leading whitespace removed, or "" if the request fails.
    """
    # Equivalent raw request body:
    # {"context":"Prompt","top_p":0.9,"temp":0.8,"response_length":128,"remove_input":true}
    api_uri = "https://api.eleuther.ai/completion"
    payload = {
        "context": text,
        "top_p": 0.9,
        "temp": 0.8,
        "response_length": kwargs.get("max_length", DEFAULT_RESPONSE_LENGTH),
        "remove_input": True,
    }
    body = _post_json(api_uri, payload)
    if body is None:
        return ""
    if DEBUG:
        print("English response >", body)
    try:
        return body[0]["generated_text"].lstrip()
    except (KeyError, IndexError, TypeError, AttributeError):
        return ""


def select_lang(lang):
    """Return the `(generate, transcribe, speak)` callables for `lang`.

    Anything other than "spanish" (case-insensitive) selects English.
    """
    if lang.lower() == "spanish":
        return generate_es, transcribe_es, speak_es
    return generate_en, transcribe_en, speak_en


def select_lang_vars(lang):
    """Return the default `(agent name, user name, context template)` for `lang`.

    The template contains `{AGENT}` and `{USER}` placeholders that are filled
    in by `chat_with_gpt`.
    """
    if lang.lower() == "spanish":
        AGENT = "BERTIN"
        USER = "ENTREVISTADOR"
        CONTEXT = (
            "La siguiente conversación es un extracto de una entrevista a {AGENT} "
            "celebrada en Madrid para Radio Televisión Española:\n\n"
            "{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.\n"
            "{AGENT}: Gracias. El placer es mío."
        )
    else:
        AGENT = "ELEUTHER"
        USER = "INTERVIEWER"
        CONTEXT = (
            "The next conversation is an excerpt from an interview to {AGENT} "
            "that appeared in the New York Times:\n\n"
            "{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.\n"
            "{AGENT}: Thanks. The pleasure is mine."
        )
    return AGENT, USER, CONTEXT


def format_chat(history):
    """Render the `(user, bot)` history pairs as the conversation log markup.

    Both sides are HTML-escaped because the result is shown in a `gr.HTML`
    component and contains transcribed speech and model-generated text.
    """
    interventions = [
        f"\n{html.escape(str(user))}\n{html.escape(str(bot))}\n"
        for user, bot in history or []
    ]
    body = "".join(interventions)
    return f"\nConversation log\n{body}\n"


def _fill_context(template, user, agent):
    """Substitute `{USER}`/`{AGENT}` in a (user-editable) context template."""
    try:
        return template.format(USER=user, AGENT=agent)
    except (KeyError, IndexError, ValueError, AttributeError):
        # Stray braces typed into the context box break str.format; fall back
        # to replacing just the two known placeholders.
        return template.replace("{USER}", user).replace("{AGENT}", agent)


def _render_history(turns, user, agent):
    """Format past `(message, response)` pairs as alternating dialogue lines."""
    return "\n".join(
        f"{user}: {past_message.capitalize()}.\n{agent}: {past_response}."
        for past_message, past_response in turns
    )


def _extract_reply(raw, prompt, agent, user):
    """Pull the agent's first turn out of a raw completion.

    Drops an echoed prompt, keeps the first non-blank chunk between
    `"{agent}:"` markers, cuts it where the user's name appears (the model
    starting to write the next question) and strips one leading punctuation
    character. Returns "" when nothing usable is left.
    """
    chunks = (raw or "").replace(prompt, "").split(f"{agent}:")
    turns = [chunk for chunk in chunks if chunk.strip()]
    if not turns:
        return ""
    reply = turns[0].split(user)[0].strip()
    if reply and reply[0] in string.punctuation:
        reply = reply[1:].strip()
    return reply


def chat_with_gpt(lang, agent, user, context, audio_in, history):
    """Run one chat turn: transcribe, generate a reply, and speak it.

    Args:
        lang: "English" or "Spanish" (case-insensitive); selects the models.
        agent: Display name of the bot; empty falls back to the language default.
        user: Display name of the human; empty falls back to the language default.
        context: Prompt preamble, may contain `{USER}` / `{AGENT}` placeholders.
        audio_in: Path of the recorded audio, or a falsy value if none.
        history: List of `(user_message, response)` tuples from previous turns.

    Returns:
        `(history, history, audio, log)` where `audio` is the spoken reply
        (or the empty clip when there was no input) and `log` is the rendered
        conversation.
    """
    history = history or []
    if not audio_in:
        return history, history, empty_audio, format_chat(history)

    generate, transcribe, speak = select_lang(lang)
    default_agent, default_user, _unused = select_lang_vars(lang)
    # Use the same names for building the prompt and for parsing the reply.
    agent = agent or default_agent
    user = user or default_user

    user_message = deuncase(transcribe(audio_in))
    generation_kwargs = {"max_length": DEFAULT_RESPONSE_LENGTH}

    # Capitalise only the first word; `sep` is "" for single-word input, so
    # the word is not duplicated.
    first_word, sep, remainder = user_message.partition(" ")
    message = first_word.capitalize() + sep + remainder

    context = _fill_context(context or "", user, agent).strip()
    if context and context[-1] not in ".:":
        context += "."
    context_length = len(context.split())

    # Drop the oldest turns until the history fits in the remaining word budget.
    budget = MAX_LENGTH - (generation_kwargs["max_length"] + context_length)
    history_take = 0
    history_context = _render_history(history, user, agent)
    while len(history_context.split()) > budget and history_take < len(history):
        history_take += 1
        history_context = _render_history(history[history_take:], user, agent)
    if history_context:
        context = f"{context}\n{history_context}" if context else history_context

    prompt = f"{context}\n\n{user}: {message}.\n"
    response = ""
    # The endpoints are flaky and sometimes just echo the question: retry a
    # few times until we get a non-empty reply that differs from the message.
    for _attempt in range(5):
        raw = generate(prompt, context_length=context_length, **generation_kwargs)
        if DEBUG:
            print(f"\n-----\n{raw}\n-----\n")
        response = _extract_reply(raw, prompt, agent, user)
        bare_response = response.replace(".", "").strip()
        if bare_response and message.replace(".", "").strip() != bare_response:
            break

    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)

    if not response.strip():
        if lang.lower() == "spanish":
            response = "Lo siento, no puedo hablar ahora"
        else:
            response = "Sorry, can't talk right now"

    history.append((user_message, response))
    return history, history, speak(response), format_chat(history)


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    lang = gr.Radio(label="Language", choices=["English", "Spanish"], value=DEFAULT_LANG, type="value")
    AGENT, USER, CONTEXT = select_lang_vars(DEFAULT_LANG)
    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)
    with gr.Row():
        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
        audio_out = gr.Audio(label="Agent", interactive=False, value=empty_audio)
    with gr.Row():
        user = gr.Textbox(label="User", value=USER)
        agent = gr.Textbox(label="Agent", value=AGENT)
    # Switching language resets the names and the context to that language's
    # defaults; the output order matches select_lang_vars' return order.
    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])
    history = gr.Variable(value=[])
    chatbot = gr.Variable()
    log = gr.HTML()
    # A new recording triggers a full chat turn.
    audio_in.change(
        chat_with_gpt,
        inputs=[lang, agent, user, context, audio_in, history],
        outputs=[chatbot, history, audio_out, log],
    )
    gr.Markdown(FOOTER)

demo.launch()