Rifky commited on
Commit
df030ab
1 Parent(s): 68ff7b5

added reset button, create new session button, and change session

Browse files
Files changed (1) hide show
  1. app.py +106 -3
app.py CHANGED
@@ -8,8 +8,11 @@ from io import BytesIO
8
  from pydub import AudioSegment
9
 
10
 
 
 
 
11
  def create_chat_session():
12
- r = requests.post("http://121.176.153.117:5000/create")
13
 
14
  if (r.status_code != 201):
15
  raise Exception("Failed to create chat session")
@@ -23,6 +26,19 @@ def create_chat_session():
23
  session_id = create_chat_session()
24
  chat_history = []
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def add_text(history, text):
28
  history = history + [(text, None)]
@@ -43,7 +59,7 @@ def add_audio(history, audio):
43
  history = history + [((f"temp_audio/{session_id}/audio_input_{audio_id}.mp3",), None)]
44
 
45
  response = requests.post(
46
- "http://121.176.153.117:5000/transcribe",
47
  files={'audio': audio_file.getvalue()}
48
  )
49
 
@@ -56,6 +72,21 @@ def add_audio(history, audio):
56
 
57
  return history, gr.update(value="", interactive=False)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def bot(history):
61
  if type(history[-1][0]) == str:
@@ -64,7 +95,7 @@ def bot(history):
64
  message = history[-2][0]
65
 
66
  response = requests.post(
67
- f"http://121.176.153.117:5000/send/text/{session_id}",
68
  headers={'Content-type': 'application/json'},
69
  json={
70
  'message': message,
@@ -93,6 +124,52 @@ def bot(history):
93
 
94
  return history
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def load_chat_history(history):
97
  global chat_history
98
  if len(chat_history) > len(history):
@@ -101,6 +178,18 @@ def load_chat_history(history):
101
 
102
 
103
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
104
  chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
105
 
106
  demo.load(load_chat_history, [chatbot], [chatbot], queue=False)
@@ -116,6 +205,12 @@ with gr.Blocks() as demo:
116
  source="microphone", type="numpy", show_label=False, format="mp3"
117
  ).style(container=False)
118
 
 
 
 
 
 
 
119
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
120
  bot, chatbot, chatbot
121
  )
@@ -126,4 +221,12 @@ with gr.Blocks() as demo:
126
  )
127
  audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)
128
 
 
 
 
 
 
 
 
 
129
  demo.launch(show_error=True)
 
8
  from pydub import AudioSegment
9
 
10
 
11
+ LOCAL_API_ENDPOINT = "http://localhost:5000"
12
+ PUBLIC_API_ENDPOINT = "http://121.176.153.117:5000"
13
+
14
  def create_chat_session():
15
+ r = requests.post(LOCAL_API_ENDPOINT + "/create")
16
 
17
  if (r.status_code != 201):
18
  raise Exception("Failed to create chat session")
 
26
  session_id = create_chat_session()
27
  chat_history = []
28
 
29
+ def create_new_or_change_session(history, id):
30
+ global session_id
31
+ global chat_history
32
+
33
+ if id == "":
34
+ session_id = create_chat_session()
35
+ history = []
36
+ else:
37
+ history, _ = change_session(history, id)
38
+
39
+ chat_history = history
40
+
41
+ return history, gr.update(value="", interactive=False)
42
 
43
  def add_text(history, text):
44
  history = history + [(text, None)]
 
59
  history = history + [((f"temp_audio/{session_id}/audio_input_{audio_id}.mp3",), None)]
60
 
61
  response = requests.post(
62
+ LOCAL_API_ENDPOINT + "/transcribe",
63
  files={'audio': audio_file.getvalue()}
64
  )
65
 
 
72
 
73
  return history, gr.update(value="", interactive=False)
74
 
75
+ def reset_chat_session(history):
76
+ global session_id
77
+ global chat_history
78
+
79
+ response = requests.post(
80
+ LOCAL_API_ENDPOINT + f"/reset/{session_id}"
81
+ )
82
+
83
+ if (response.status_code != 200):
84
+ raise Exception(response.text)
85
+
86
+ history = []
87
+ chat_history = []
88
+
89
+ return history
90
 
91
  def bot(history):
92
  if type(history[-1][0]) == str:
 
95
  message = history[-2][0]
96
 
97
  response = requests.post(
98
+ LOCAL_API_ENDPOINT + f"/send/text/{session_id}",
99
  headers={'Content-type': 'application/json'},
100
  json={
101
  'message': message,
 
124
 
125
  return history
126
 
127
+ def change_session(history, id):
128
+ global session_id
129
+ global chat_history
130
+
131
+ response = requests.get(
132
+ LOCAL_API_ENDPOINT + f"/{id}"
133
+ )
134
+
135
+ if (response.status_code != 200):
136
+ raise Exception(response.text)
137
+
138
+ response = response.json()
139
+
140
+ session_id = id
141
+
142
+ history = []
143
+
144
+ try:
145
+ for chat in response:
146
+ if chat['role'] == 'user':
147
+ if chat['audio'] != "":
148
+ audio_bytes = base64.b64decode(chat['audio'].encode('utf-8'))
149
+ audio_file = BytesIO(audio_bytes)
150
+ audio_id = secrets.token_hex(8)
151
+ AudioSegment.from_file(audio_file).export(f"temp_audio/{id}/audio_input_{audio_id}.mp3", format="mp3")
152
+ history = history + [((f"temp_audio/{id}/audio_input_{audio_id}.mp3",), None)]
153
+ history = history + [(chat['message'], None)]
154
+ elif chat['role'] == 'assistant':
155
+ audio_bytes = base64.b64decode(chat['audio'].encode('utf-8'))
156
+ audio_file = BytesIO(audio_bytes)
157
+ audio_id = secrets.token_hex(8)
158
+ AudioSegment.from_file(audio_file).export(f"temp_audio/{id}/audio_input_{audio_id}.mp3", format="mp3")
159
+
160
+ history = history + [(None, (f"temp_audio/{id}/audio_input_{audio_id}.mp3",))]
161
+ history = history + [(None, chat['message'])]
162
+ else:
163
+ raise Exception("Invalid chat role")
164
+ except Exception as e:
165
+ raise Exception(f"Response: {response}")
166
+
167
+ chat_history = history.copy()
168
+
169
+ print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(response)}")
170
+
171
+ return history, gr.update(value="", interactive=False)
172
+
173
  def load_chat_history(history):
174
  global chat_history
175
  if len(chat_history) > len(history):
 
178
 
179
 
180
  with gr.Blocks() as demo:
181
+ with gr.Row():
182
+ # change session id
183
+ change_session_txt = gr.Textbox(
184
+ show_label=False,
185
+ placeholder=session_id,
186
+ ).style(container=False)
187
+ with gr.Row():
188
+ # button to create new or change session id
189
+ change_session_button = gr.Button(
190
+ "Create new or change session", type='success', size="sm"
191
+ ).style(margin="0 10px 0 0", container=False)
192
+
193
  chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
194
 
195
  demo.load(load_chat_history, [chatbot], [chatbot], queue=False)
 
205
  source="microphone", type="numpy", show_label=False, format="mp3"
206
  ).style(container=False)
207
 
208
+
209
+ with gr.Row():
210
+ reset_button = gr.Button(
211
+ "Reset Chat Session", type='stop', size="sm"
212
+ ).style(margin="0 10px 0 0", container=False)
213
+
214
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
215
  bot, chatbot, chatbot
216
  )
 
221
  )
222
  audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)
223
 
224
+ reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)
225
+
226
+ chgn_msg = change_session_txt.submit(change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
227
+ chgn_msg.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
228
+
229
+ create_new_or_change_session_btn = change_session_button.click(create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
230
+ create_new_or_change_session_btn.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
231
+
232
  demo.launch(show_error=True)