YingxuHe commited on
Commit
8c2ea06
·
1 Parent(s): a31a6df

add logging function

Browse files
src/content/common.py CHANGED
@@ -4,7 +4,8 @@ import numpy as np
4
  import streamlit as st
5
 
6
  from src.tunnel import start_server
7
- from src.generation import FIXED_GENERATION_CONFIG, load_model
 
8
 
9
 
10
  DEFAULT_DIALOGUE_STATES = dict(
@@ -308,6 +309,10 @@ def init_state_section():
308
  unsafe_allow_html=True
309
  )
310
 
 
 
 
 
311
  if "server" not in st.session_state:
312
  st.session_state.server = start_server()
313
 
@@ -328,7 +333,7 @@ def header_section(component_name="Playground", icon="🤖"):
328
  f"<h1 style='text-align: center;'>MERaLiON-AudioLLM {component_name} {icon}</h1>",
329
  unsafe_allow_html=True
330
  )
331
-
332
  st.markdown(
333
  f"""<div class="main-intro-normal-window">
334
  <p>This {component_name.lower()} is based on
@@ -336,7 +341,8 @@ def header_section(component_name="Playground", icon="🤖"):
336
  target="_blank" rel="noopener noreferrer"> MERaLiON-AudioLLM</a>,
337
  developed by I2R, A*STAR, in collaboration with AISG, Singapore.
338
  It is tailored for Singapore’s multilingual and multicultural landscape.
339
- MERaLiON-AudioLLM supports <strong>Automatic Speech Recognation</strong>,
 
340
  <strong>Speech Translation</strong>,
341
  <strong>Spoken Question Answering</strong>,
342
  <strong>Spoken Dialogue Summarization</strong>,
@@ -356,9 +362,9 @@ def header_section(component_name="Playground", icon="🤖"):
356
 
357
  @st.fragment
358
  def sidebar_fragment():
359
- with st.container(height=300, border=False):
360
- st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="Playground")
361
- st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="Voice Chat (experimental)")
362
 
363
 
364
  st.divider()
@@ -367,4 +373,47 @@ def sidebar_fragment():
367
 
368
  st.slider(label='Top P', min_value=0.0, max_value=1.0, value=0.9, key='top_p')
369
 
370
- st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
5
 
6
  from src.tunnel import start_server
7
+ from src.generation import FIXED_GENERATION_CONFIG, load_model, retrive_response
8
+ from src.logger import load_logger
9
 
10
 
11
  DEFAULT_DIALOGUE_STATES = dict(
 
309
  unsafe_allow_html=True
310
  )
311
 
312
+ if "logger" not in st.session_state:
313
+ st.session_state.logger = load_logger()
314
+ st.session_state.session_id = st.session_state.logger.register_session()
315
+
316
  if "server" not in st.session_state:
317
  st.session_state.server = start_server()
318
 
 
333
  f"<h1 style='text-align: center;'>MERaLiON-AudioLLM {component_name} {icon}</h1>",
334
  unsafe_allow_html=True
335
  )
336
+
337
  st.markdown(
338
  f"""<div class="main-intro-normal-window">
339
  <p>This {component_name.lower()} is based on
 
341
  target="_blank" rel="noopener noreferrer"> MERaLiON-AudioLLM</a>,
342
  developed by I2R, A*STAR, in collaboration with AISG, Singapore.
343
  It is tailored for Singapore’s multilingual and multicultural landscape.
344
+ MERaLiON-AudioLLM supports
345
+ <strong>Automatic Speech Recognition</strong>,
346
  <strong>Speech Translation</strong>,
347
  <strong>Spoken Question Answering</strong>,
348
  <strong>Spoken Dialogue Summarization</strong>,
 
362
 
363
  @st.fragment
364
  def sidebar_fragment():
365
+ with st.container(height=256, border=False):
366
+ st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="🚀 Playground")
367
+ st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="🗣️ Voice Chat (experimental)")
368
 
369
 
370
  st.divider()
 
373
 
374
  st.slider(label='Top P', min_value=0.0, max_value=1.0, value=0.9, key='top_p')
375
 
376
+ st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
377
+
378
+
379
def retrive_response_with_ui(prompt, array_audio, base64_audio, stream):
    """Query the model, render the outcome with Streamlit, and log the exchange.

    Args:
        prompt: text instruction sent alongside the audio.
        array_audio: decoded audio samples (used by input validation downstream).
        base64_audio: base64-encoded audio payload sent to the model.
        stream: whether to request a streaming response from the backend.

    Returns:
        (error_msg, warnings, response): error string ("" on success), a list
        of validation warning messages, and the accumulated response text.
    """
    # Snapshot generation parameters from the sidebar-controlled session state
    # so the logged record reflects exactly what was used for this query.
    generation_params = dict(
        model=st.session_state.model_name,
        max_completion_tokens=st.session_state.max_completion_tokens,
        temperature=st.session_state.temperature,
        top_p=st.session_state.top_p,
        extra_body={
            "repetition_penalty": st.session_state.repetition_penalty,
            "top_k": st.session_state.top_k,
            "length_penalty": st.session_state.length_penalty
        },
        seed=st.session_state.seed
    )

    # BUG FIX: honour the caller's `stream` flag instead of hard-coding True,
    # and don't shadow the parameter with the returned stream object.
    # NOTE(review): the render path below assumes a streaming iterator; if a
    # caller ever passes stream=False, st.write_stream may need a plain
    # st.write fallback — confirm before adding such a caller.
    error_msg, warnings, response_stream = retrive_response(
        prompt,
        array_audio,
        base64_audio,
        params=generation_params,
        stream=stream
    )
    response = ""

    if error_msg:
        st.error(error_msg)
    for warning_msg in warnings:
        st.warning(warning_msg)
    if response_stream:
        response = st.write_stream(response_stream)

    # Persist the full exchange (audio kept separately by the logger).
    st.session_state.logger.register_query(
        session_id=st.session_state.session_id,
        base64_audio=base64_audio,
        text_input=prompt,
        params=generation_params,
        response=response,
        warnings=warnings,
        error_msg=error_msg
    )

    return error_msg, warnings, response
src/content/playground.py CHANGED
@@ -3,14 +3,15 @@ import base64
3
 
4
  import streamlit as st
5
 
6
- from src.generation import MAX_AUDIO_LENGTH, retrive_response
7
  from src.utils import bytes_to_array, array_to_bytes
8
  from src.content.common import (
9
  AUDIO_SAMPLES_W_INSTRUCT,
10
  DEFAULT_DIALOGUE_STATES,
11
  init_state_section,
12
  header_section,
13
- sidebar_fragment
 
14
  )
15
 
16
 
@@ -173,20 +174,12 @@ def conversation_section():
173
 
174
  with st.chat_message("assistant"):
175
  with st.spinner("Thinking..."):
176
- error_msg, warnings, stream = retrive_response(
177
  one_time_prompt,
178
  st.session_state.pg_audio_array,
179
  st.session_state.pg_audio_base64,
180
  stream=True
181
  )
182
- response = ""
183
-
184
- if error_msg:
185
- st.error(error_msg)
186
- for warning_msg in warnings:
187
- st.warning(warning_msg)
188
- if stream:
189
- response = st.write_stream(stream)
190
 
191
  st.session_state.pg_messages.append({
192
  "role": "assistant",
 
3
 
4
  import streamlit as st
5
 
6
+ from src.generation import MAX_AUDIO_LENGTH
7
  from src.utils import bytes_to_array, array_to_bytes
8
  from src.content.common import (
9
  AUDIO_SAMPLES_W_INSTRUCT,
10
  DEFAULT_DIALOGUE_STATES,
11
  init_state_section,
12
  header_section,
13
+ sidebar_fragment,
14
+ retrive_response_with_ui
15
  )
16
 
17
 
 
174
 
175
  with st.chat_message("assistant"):
176
  with st.spinner("Thinking..."):
177
+ error_msg, warnings, response = retrive_response_with_ui(
178
  one_time_prompt,
179
  st.session_state.pg_audio_array,
180
  st.session_state.pg_audio_base64,
181
  stream=True
182
  )
 
 
 
 
 
 
 
 
183
 
184
  st.session_state.pg_messages.append({
185
  "role": "assistant",
src/content/voice_chat.py CHANGED
@@ -4,13 +4,14 @@ import base64
4
  import numpy as np
5
  import streamlit as st
6
 
7
- from src.generation import retrive_response
8
  from src.utils import bytes_to_array, array_to_bytes
9
  from src.content.common import (
10
  DEFAULT_DIALOGUE_STATES,
11
  init_state_section,
12
  header_section,
13
- sidebar_fragment
 
14
  )
15
 
16
 
@@ -18,9 +19,6 @@ from src.content.common import (
18
  DEFAULT_PROMPT = "Please follow the instruction in the speech."
19
 
20
 
21
- MAX_AUDIO_LENGTH = 120
22
-
23
-
24
  def _update_audio(audio_bytes):
25
  origin_audio_array = bytes_to_array(audio_bytes)
26
  truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
@@ -123,20 +121,12 @@ def conversation_section():
123
 
124
  with st.chat_message("assistant"):
125
  with st.spinner("Thinking..."):
126
- error_msg, warnings, stream = retrive_response(
127
  one_time_prompt,
128
  one_time_array,
129
  one_time_base64,
130
  stream=True
131
  )
132
- response = ""
133
-
134
- if error_msg:
135
- st.error(error_msg)
136
- for warning_msg in warnings:
137
- st.warning(warning_msg)
138
- if stream:
139
- response = st.write_stream(stream)
140
 
141
  st.session_state.vc_messages.append({
142
  "role": "assistant",
 
4
  import numpy as np
5
  import streamlit as st
6
 
7
+ from src.generation import MAX_AUDIO_LENGTH
8
  from src.utils import bytes_to_array, array_to_bytes
9
  from src.content.common import (
10
  DEFAULT_DIALOGUE_STATES,
11
  init_state_section,
12
  header_section,
13
+ sidebar_fragment,
14
+ retrive_response_with_ui
15
  )
16
 
17
 
 
19
  DEFAULT_PROMPT = "Please follow the instruction in the speech."
20
 
21
 
 
 
 
22
  def _update_audio(audio_bytes):
23
  origin_audio_array = bytes_to_array(audio_bytes)
24
  truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
 
121
 
122
  with st.chat_message("assistant"):
123
  with st.spinner("Thinking..."):
124
+ error_msg, warnings, response = retrive_response_with_ui(
125
  one_time_prompt,
126
  one_time_array,
127
  one_time_base64,
128
  stream=True
129
  )
 
 
 
 
 
 
 
 
130
 
131
  st.session_state.vc_messages.append({
132
  "role": "assistant",
src/generation.py CHANGED
@@ -64,7 +64,7 @@ def _retrive_response(text_input: str, base64_audio_input: str, **kwargs):
64
  )
65
 
66
 
67
- def _retry_retrive_response_throws_exception(text_input, base64_audio_input, stream=False, retry=3):
68
  if not base64_audio_input:
69
  raise NoAudioException("audio is empty.")
70
 
@@ -72,17 +72,8 @@ def _retry_retrive_response_throws_exception(text_input, base64_audio_input, str
72
  response_object = _retrive_response(
73
  text_input=text_input,
74
  base64_audio_input=base64_audio_input,
75
- model=st.session_state.model_name,
76
- max_completion_tokens=st.session_state.max_completion_tokens,
77
- temperature=st.session_state.temperature,
78
- top_p=st.session_state.top_p,
79
- extra_body={
80
- "repetition_penalty": st.session_state.repetition_penalty,
81
- "top_k": st.session_state.top_k,
82
- "length_penalty": st.session_state.length_penalty
83
- },
84
- seed=st.session_state.seed,
85
- stream=stream
86
  )
87
  except APIConnectionError as e:
88
  if not st.session_state.server.is_running():
@@ -122,13 +113,13 @@ def _validate_input(text_input, array_audio_input) -> List[str]:
122
  return warnings
123
 
124
 
125
- def retrive_response(text_input, array_audio_input, base64_audio_input, stream=False):
126
  warnings = _validate_input(text_input, array_audio_input)
127
 
128
  response_object, error_msg = None, ""
129
  try:
130
  response_object = _retry_retrive_response_throws_exception(
131
- text_input, base64_audio_input, stream
132
  )
133
  except NoAudioException:
134
  error_msg = "Please specify audio first!"
 
64
  )
65
 
66
 
67
+ def _retry_retrive_response_throws_exception(text_input, base64_audio_input, params, stream=False, retry=3):
68
  if not base64_audio_input:
69
  raise NoAudioException("audio is empty.")
70
 
 
72
  response_object = _retrive_response(
73
  text_input=text_input,
74
  base64_audio_input=base64_audio_input,
75
+ stream=stream,
76
+ **params
 
 
 
 
 
 
 
 
 
77
  )
78
  except APIConnectionError as e:
79
  if not st.session_state.server.is_running():
 
113
  return warnings
114
 
115
 
116
+ def retrive_response(text_input, array_audio_input, base64_audio_input, params, stream=False):
117
  warnings = _validate_input(text_input, array_audio_input)
118
 
119
  response_object, error_msg = None, ""
120
  try:
121
  response_object = _retry_retrive_response_throws_exception(
122
+ text_input, base64_audio_input, params, stream
123
  )
124
  except NoAudioException:
125
  error_msg = "Please specify audio first!"
src/logger.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ import json
5
+ from threading import Thread, Lock
6
+
7
+ import streamlit as st
8
+ from huggingface_hub import HfApi
9
+
10
+ from src.utils import get_current_strftime
11
+
12
+
13
+ logger_lock = Lock()
14
+
15
+
16
def threaded(fn):
    """Decorator: run `fn` in a background daemon thread.

    The wrapped callable starts the thread immediately and returns the
    `Thread` object so callers may join it if needed.
    """
    import functools  # local import: keeps this edit self-contained

    @functools.wraps(fn)  # preserve fn's name/doc for debugging and caching
    def wrapper(*args, **kwargs):
        # daemon=True so a forever-looping worker (e.g. Logger.sync_data)
        # cannot keep the interpreter alive at shutdown.
        thread = Thread(target=fn, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread
    return wrapper
22
+
23
+
24
class Logger:
    """Buffer session/query/audio records and periodically upload them as
    JSON-lines files to a Hugging Face dataset repository.

    A background thread (started from ``__init__`` via the ``threaded``
    decorator on ``sync_data``) drains the in-memory buffers every
    ``sync_interval`` seconds. All buffer and counter access is guarded by
    the module-level ``logger_lock`` because multiple Streamlit sessions may
    call the ``register_*`` methods concurrently with the flusher thread.
    """

    def __init__(self):
        # App-level prefix for session ids; one Logger is shared per process
        # (see the cached load_logger factory).
        self.app_id = get_current_strftime()
        self.session_increment = 0
        self.query_increment = 0
        self.sync_interval = 180  # seconds between upload cycles

        # Pending rows, drained by sync_data.
        self.session_data = []
        self.query_data = []
        self.audio_data = []

        # sync_data is @threaded: this call starts the background flusher.
        self.sync_data()

    def register_session(self) -> str:
        """Allocate a unique session id and record its creation time."""
        with logger_lock:
            # BUG FIX: read-and-bump the counter inside the lock so two
            # concurrent sessions can never receive the same id.
            new_session_id = f"{self.app_id}+{self.session_increment}"
            self.session_increment += 1
            self.session_data.append({
                "session_id": new_session_id,
                "creation_time": get_current_strftime()
            })
        return new_session_id

    def register_query(self,
        session_id,
        base64_audio,
        text_input,
        params,
        response,
        warnings,
        error_msg
    ):
        """Record one model query.

        The (potentially large) audio payload is stored in a separate buffer
        from the query metadata; the two are linked by (session_id, query_id).
        """
        current_time = get_current_strftime()

        with logger_lock:
            # BUG FIX: allocate the query id under the lock (atomic with the
            # increment) so concurrent queries get distinct ids.
            new_query_id = self.query_increment
            self.query_increment += 1

            self.query_data.append({
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "text": text_input,
                "params": params,
                "response": response,
                "warnings": warnings,
                "error": error_msg,
            })

            self.audio_data.append({
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "audio": base64_audio,
            })

    @threaded
    def sync_data(self):
        """Background loop: every ``sync_interval`` seconds, drain each
        non-empty buffer and upload it as a JSON-lines file."""
        api = HfApi()

        while True:
            time.sleep(self.sync_interval)

            for data_name in ["session_data", "query_data", "audio_data"]:
                # Swap the buffer out under the lock so writers are blocked
                # only for the swap, not for the network upload.
                with logger_lock:
                    last_data = getattr(self, data_name, [])
                    setattr(self, data_name, [])

                if not last_data:
                    continue

                buffer = io.BytesIO()
                for row in last_data:
                    row_str = json.dumps(row, ensure_ascii=False) + "\n"
                    buffer.write(row_str.encode("utf-8"))
                # BUG FIX: rewind so upload_file reads the content from the
                # start rather than from the end-of-write position.
                buffer.seek(0)

                try:
                    api.upload_file(
                        path_or_fileobj=buffer,
                        # NOTE(review): content is JSON-lines despite the
                        # ".json" suffix — kept as-is for compatibility.
                        path_in_repo=f"{data_name}/{get_current_strftime()}.json",
                        repo_id=os.getenv("LOGGING_REPO_NAME"),
                        repo_type="dataset",
                        token=os.getenv('HF_TOKEN')
                    )
                except Exception:
                    # BUG FIX: a transient upload failure must not kill the
                    # flusher thread or silently drop the drained rows;
                    # put them back for the next cycle.
                    with logger_lock:
                        setattr(
                            self,
                            data_name,
                            last_data + getattr(self, data_name, [])
                        )
108
+
109
+
110
@st.cache_resource()
def load_logger():
    """Return the process-wide Logger singleton.

    ``st.cache_resource`` guarantees a single shared instance across all
    Streamlit sessions and reruns of this server process.
    """
    return Logger()
src/utils.py CHANGED
@@ -1,9 +1,14 @@
1
  import io
 
2
  from scipy.io.wavfile import write
3
 
4
  import librosa
5
 
6
 
 
 
 
 
7
  def bytes_to_array(audio_bytes):
8
  audio_array, _ = librosa.load(
9
  io.BytesIO(audio_bytes),
 
1
  import io
2
+ from datetime import datetime
3
  from scipy.io.wavfile import write
4
 
5
  import librosa
6
 
7
 
8
def get_current_strftime():
    """Return the current local time as a 'DD-MM-YY-HH-MM-SS' string."""
    timestamp_format = r'%d-%m-%y-%H-%M-%S'
    now = datetime.now()
    return now.strftime(timestamp_format)
10
+
11
+
12
  def bytes_to_array(audio_bytes):
13
  audio_array, _ = librosa.load(
14
  io.BytesIO(audio_bytes),
style/app_style.css CHANGED
@@ -1,7 +1,16 @@
 
 
 
 
 
1
  div[data-testid="stMainBlockContainer"] div[data-testid="stAudioInput"]>div {
2
  max-height: 3rem;
3
  }
4
 
 
 
 
 
5
  div[class="sidebar-intro"] p {
6
  margin-bottom: 0.75rem;
7
  }
@@ -16,8 +25,15 @@ div[data-testid="stChatMessage"]:has(> div[data-testid="stChatMessageAvatarUser"
16
  text-align: right;
17
  }
18
 
 
 
 
 
 
 
19
  div[data-testid="stChatMessage"] div[data-testid="stHorizontalBlock"]:has(> div[data-testid="stColumn"]) {
20
  flex-direction: row-reverse;
 
21
  }
22
 
23
  div[data-testid="stChatMessage"] div[data-testid="stHorizontalBlock"]>div[data-testid="stColumn"]:has( div[data-testid="stButton"]) {
 
1
+ div[data-testid="stMainBlockContainer"] {
2
+ padding-top: 2rem;
3
+ padding-bottom: 1rem;
4
+ }
5
+
6
  div[data-testid="stMainBlockContainer"] div[data-testid="stAudioInput"]>div {
7
  max-height: 3rem;
8
  }
9
 
10
+ div[data-testid="stMainBlockContainer"] h1 {
11
+ padding-top: 0.25rem;
12
+ }
13
+
14
  div[class="sidebar-intro"] p {
15
  margin-bottom: 0.75rem;
16
  }
 
25
  text-align: right;
26
  }
27
 
28
+ /* audio quick actions */
29
+
30
+ div[data-testid="stChatMessage"] div[data-testid="stVerticalBlock"]:has( audio[data-testid="stAudio"]) {
31
+ gap: 2px;
32
+ }
33
+
34
  div[data-testid="stChatMessage"] div[data-testid="stHorizontalBlock"]:has(> div[data-testid="stColumn"]) {
35
  flex-direction: row-reverse;
36
+ gap: 4px;
37
  }
38
 
39
  div[data-testid="stChatMessage"] div[data-testid="stHorizontalBlock"]>div[data-testid="stColumn"]:has( div[data-testid="stButton"]) {
style/normal_window.css CHANGED
@@ -1,6 +1,7 @@
1
  @media(min-width: 576px) {
2
- .stMainBlockContainer {
3
- padding: 2rem 5rem 1rem;
 
4
  }
5
 
6
  div[data-testid="stBottomBlockContainer"] {
 
1
  @media(min-width: 576px) {
2
+ div[data-testid="stMainBlockContainer"] {
3
+ padding-left: 5rem;
4
+ padding-bottom: 5rem;
5
  }
6
 
7
  div[data-testid="stBottomBlockContainer"] {
style/small_window.css CHANGED
@@ -1,4 +1,9 @@
1
  @media(max-width: 576px) {
 
 
 
 
 
2
  div[data-testid="stMainBlockContainer"] div[data-testid="stVerticalBlock"]>div[data-testid="stElementContainer"]:has( div[data-testid="stHeadingWithActionElements"]) {
3
  display: none;
4
  }
@@ -6,4 +11,8 @@
6
  div[class="main-intro-normal-window"] {
7
  display: none;
8
  }
 
 
 
 
9
  }
 
1
  @media(max-width: 576px) {
2
+ div[data-testid="stMainBlockContainer"] {
3
+ padding-left: 1rem;
4
+ padding-bottom: 1rem;
5
+ }
6
+
7
  div[data-testid="stMainBlockContainer"] div[data-testid="stVerticalBlock"]>div[data-testid="stElementContainer"]:has( div[data-testid="stHeadingWithActionElements"]) {
8
  display: none;
9
  }
 
11
  div[class="main-intro-normal-window"] {
12
  display: none;
13
  }
14
+
15
+ div[data-testid="stSidebarCollapsedControl"] button[data-testid="stBaseButton-headerNoPadding"]::after {
16
+ content: "More Use Cases"
17
+ }
18
  }