versae committed on
Commit 15c5469
1 Parent(s): 38c5edb

Fix speaker order and agent response

Files changed (2)
  1. duplex.py +221 -0
  2. gradio_app.py +10 -6
duplex.py ADDED
@@ -0,0 +1,221 @@
+import os
+import json
+import logging
+import random
+import string
+import sys
+
+import numpy as np
+import gradio as gr
+import requests
+import soundfile as sf
+
+from transformers import pipeline, set_seed
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
+
+# DEBUG is truthy when the env var starts with "t", "y", or "1" (e.g. "true", "yes", "1")
+DEBUG = os.environ.get("DEBUG", "false")[0] in "ty1"
+MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
+DEFAULT_LANG = os.environ.get("DEFAULT_LANG", "English")
+HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
+
+HEADER = """
+# Poor Man's Duplex
+
+Talk to a language model like you talk on a walkie-talkie! Well, with larger latencies.
+The models are [EleutherAI's GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) for English, and [BERTIN GPT-J-6B](https://huggingface.co/bertin-project/bertin-gpt-j-6B) for Spanish.
+""".strip()
+
+FOOTER = """
+<div align=center>
+<img src="https://visitor-badge.glitch.me/badge?page_id=versae/poor-mans-duplex"/>
+</div>
+""".strip()
+
+asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
+model_instance_es = AutoModelForCTC.from_pretrained(asr_model_name_es, use_auth_token=HF_AUTH_TOKEN)
+processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_es, use_auth_token=HF_AUTH_TOKEN)
+asr_es = pipeline(
+    "automatic-speech-recognition",
+    model=model_instance_es,
+    tokenizer=processor_es.tokenizer,
+    feature_extractor=processor_es.feature_extractor,
+    decoder=processor_es.decoder
+)
+tts_model_name = "facebook/tts_transformer-es-css10"
+speak_es = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN)
+transcribe_es = lambda input_file: asr_es(input_file, chunk_length_s=5, stride_length_s=1)["text"]
+def generate_es(text, **kwargs):
+    # Space API signature: text="Prompt", max_length=100, top_k=100, top_p=50, temperature=0.95, do_sample=True, do_clean=True
+    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
+    response = requests.post(api_uri, data=json.dumps({"data": [text, kwargs["max_length"], 100, 50, 0.95, True, True]}))
+    if response.ok:
+        if DEBUG:
+            print("Spanish response >", response.json())
+        return response.json()["data"][0]
+    else:
+        return ""
+
+asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
+model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en)
+processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en)
+asr_en = pipeline(
+    "automatic-speech-recognition",
+    model=model_instance_en,
+    tokenizer=processor_en.tokenizer,
+    feature_extractor=processor_en.feature_extractor,
+    decoder=processor_en.decoder
+)
+tts_model_name = "facebook/fastspeech2-en-ljspeech"
+speak_en = gr.Interface.load(f"huggingface/{tts_model_name}", api_key=HF_AUTH_TOKEN)
+transcribe_en = lambda input_file: asr_en(input_file, chunk_length_s=5, stride_length_s=1)["text"]
+# generate_iface = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B", api_key=HF_AUTH_TOKEN)
+
+empty_audio = "empty.flac"
+sf.write(empty_audio, [], 16000)
+deuncase = gr.Interface.load("huggingface/pere/DeUnCaser", api_key=HF_AUTH_TOKEN)
+
+def generate_en(text, **kwargs):
+    api_uri = "https://api.eleuther.ai/completion"
+    # --data-raw '{"context":"Prompt","top_p":0.9,"temp":0.8,"response_length":128,"remove_input":true}'
+    response = requests.post(api_uri, data=json.dumps({"context": text, "top_p": 0.9, "temp": 0.8, "response_length": kwargs["max_length"], "remove_input": True}))
+    if response.ok:
+        if DEBUG:
+            print("English response >", response.json())
+        return response.json()[0]["generated_text"].lstrip()
+    else:
+        return ""
+
+
+def select_lang(lang):
+    if lang.lower() == "spanish":
+        return generate_es, transcribe_es, speak_es
+    else:
+        return generate_en, transcribe_en, speak_en
+
+
+def select_lang_vars(lang):
+    if lang.lower() == "spanish":
+        AGENT = "BERTIN"
+        USER = "ENTREVISTADOR"
+        CONTEXT = """La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:
+
+{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.
+{AGENT}: Gracias. El placer es mío."""
+    else:
+        AGENT = "ELEUTHER"
+        USER = "INTERVIEWER"
+        CONTEXT = """The next conversation is an excerpt from an interview with {AGENT} that appeared in the New York Times:
+
+{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.
+{AGENT}: Thanks. The pleasure is mine."""
+
+    return AGENT, USER, CONTEXT
+
+
+def format_chat(history):
+    interventions = []
+    for user, bot in history:
+        interventions.append(f"""
+<div data-testid="user" style="background-color:#16a34a" class="px-3 py-2 rounded-[22px] rounded-bl-none place-self-start text-white ml-7 text-sm">{user}</div>
+<div data-testid="bot" style="background-color:gray" class="px-3 py-2 rounded-[22px] rounded-br-none text-white ml-7 text-sm">{bot}</div>
+""")
+    return f"""<details><summary>Conversation log</summary>
+<div class="overflow-y-auto h-[40vh]">
+    <div class="flex flex-col items-end space-y-4 p-3">
+        {"".join(interventions)}
+    </div>
+</div>
+</details>"""
+
+
+def chat_with_gpt(lang, agent, user, context, audio_in, history):
+    if not audio_in:
+        return history, history, empty_audio, format_chat(history)
+    generate, transcribe, speak = select_lang(lang)
+    AGENT, USER, _ = select_lang_vars(lang)
+    user_message = deuncase(transcribe(audio_in))
+    # agent = AGENT
+    # user = USER
+    generation_kwargs = {
+        "max_length": 50,
+        # "top_k": top_k,
+        # "top_p": top_p,
+        # "temperature": temperature,
+        # "do_sample": do_sample,
+        # "do_clean": do_clean,
+        # "num_return_sequences": 1,
+        # "return_full_text": False,
+    }
+    # Capitalize only the first word of the transcription
+    message = user_message.split(" ", 1)[0].capitalize() + " " + user_message.split(" ", 1)[-1]
+    history = history or []  # [(f"{user}: Bienvenido. Encantado de tenerle con nosotros.", f"{agent}: Un placer, muchas gracias por la invitación.")]
+    context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
+    if context[-1] not in ".:":
+        context += "."
+    context_length = len(context.split())
+    # Drop the oldest turns until the prompt fits within the MAX_LENGTH word budget
+    history_take = 0
+    history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
+    while len(history_context.split()) > MAX_LENGTH - (generation_kwargs["max_length"] + context_length):
+        history_take += 1
+        history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
+        if history_take >= MAX_LENGTH:
+            break
+    context += history_context
+    # Retry a few times until the model yields a response that is neither empty nor an echo of the message
+    for _ in range(5):
+        response = generate(f"{context}\n\n{user}: {message}.\n", context_length=context_length, **generation_kwargs)
+        if DEBUG:
+            print("\n-----" + response + "-----\n")
+        # response = response.split("\n")[-1]
+        # if agent in response and response.split(agent)[-1]:
+        #     response = response.split(agent)[-1]
+        # if user in response and response.split(user)[-1]:
+        #     response = response.split(user)[-1]
+        # Take the first response
+        response = [
+            r for r in response.split(f"{AGENT}:") if r.strip()
+        ][0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()
+        if response and response[0] in string.punctuation:
+            response = response[1:].strip()
+        if response.strip().startswith(f"{user}: {message}"):
+            response = response.strip().split(f"{user}: {message}")[-1]
+        if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
+            break
+    if DEBUG:
+        print()
+        print("CONTEXT:")
+        print(context)
+        print()
+        print("MESSAGE:")
+        print(message)
+        print()
+        print("RESPONSE:")
+        print(response)
+    if not response.strip():
+        response = "Lo siento, no puedo hablar ahora" if lang.lower() == "spanish" else "Sorry, can't talk right now"
+    history.append((user_message, response))
+    return history, history, speak(response), format_chat(history)
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(HEADER)
+    lang = gr.Radio(label="Language", choices=["English", "Spanish"], value=DEFAULT_LANG, type="value")
+    AGENT, USER, CONTEXT = select_lang_vars(DEFAULT_LANG)
+    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)
+    with gr.Row():
+        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
+        audio_out = gr.Audio(label="Agent", interactive=False, value=empty_audio)
+    # chat_btn = gr.Button("Submit")
+    with gr.Row():
+        user = gr.Textbox(label="User", value=USER)
+        agent = gr.Textbox(label="Agent", value=AGENT)
+    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])
+    history = gr.Variable(value=[])
+    chatbot = gr.Variable()  # gr.Chatbot(color_map=("green", "gray"), visible=False)
+    # chat_btn.click(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out])
+    log = gr.HTML()
+    audio_in.change(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out, log])
+    gr.Markdown(FOOTER)
+
+demo.launch()
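For reference, a minimal standalone sketch of what the new "take the first response" extraction in duplex.py does; the generated text below is made up for illustration, and AGENT/USER match the English defaults above:

AGENT = "ELEUTHER"
USER = "INTERVIEWER"
generated = " That is a great question.\nINTERVIEWER: Tell us more.\nELEUTHER: Sure."

# Keep only the first non-empty agent turn, cutting it off at the next user turn.
response = [
    r for r in generated.split(f"{AGENT}:") if r.strip()
][0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()

print(response)  # -> That is a great question.

The old logic (now commented out) kept only the last line and then split on the speaker names, which could return a user turn or an empty string; the new version always returns the agent's first complete turn.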
gradio_app.py CHANGED
@@ -230,7 +230,7 @@ def expand_with_gpt(hidden, text, max_length, top_k, top_p, temperature, do_samp
     }
     return generator.generate(text, generation_kwargs, previous_text=hidden)
 
-def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
+def chat_with_gpt(agent, user, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
     # agent = AGENT
     # user = USER
     generation_kwargs = {
@@ -261,11 +261,15 @@ def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k
     response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[0]
     if DEBUG:
         print("\n-----" + response + "-----\n")
-    response = response.split("\n")[-1]
-    if agent in response and response.split(agent)[-1]:
-        response = response.split(agent)[-1]
-    if user in response and response.split(user)[-1]:
-        response = response.split(user)[-1]
+    # response = response.split("\n")[-1]
+    # if agent in response and response.split(agent)[-1]:
+    #     response = response.split(agent)[-1]
+    # if user in response and response.split(user)[-1]:
+    #     response = response.split(user)[-1]
+    # Take the first response
+    response = [
+        r for r in response.split(f"{AGENT}:") if r.strip()
+    ][0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()
     if response[0] in string.punctuation:
         response = response[1:].strip()
     if response.strip().startswith(f"{user}: {message}"):
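A note on the signature change above: Gradio passes the values of the inputs components to the handler positionally, so the parameter order must mirror the component list. A minimal sketch of the pattern (the component and function names here are illustrative, not from the repo):

import gradio as gr

def label_turn(agent, user):
    # Parameters receive values in the same order as `inputs` below.
    return f"{agent} replies to {user}"

with gr.Blocks() as sketch:
    agent_box = gr.Textbox(label="Agent", value="ELEUTHER")
    user_box = gr.Textbox(label="User", value="INTERVIEWER")
    out = gr.Textbox(label="Out")
    # inputs=[agent_box, user_box] maps onto (agent, user); listing the
    # components in one order and the parameters in another silently swaps
    # the speakers, which is the ordering bug this commit corrects.
    agent_box.change(label_turn, inputs=[agent_box, user_box], outputs=[out])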