benjolo commited on
Commit
831f555
1 Parent(s): 0a3ee2a

Fuzzy search fix

Browse files
Files changed (45) hide show
  1. backend/.DS_Store +0 -0
  2. backend/.env +2 -0
  3. backend/.gitignore +2 -0
  4. backend/Client.py +81 -0
  5. backend/__pycache__/Client.cpython-310.pyc +0 -0
  6. backend/__pycache__/main.cpython-310.pyc +0 -0
  7. backend/logging.yaml +22 -0
  8. backend/main.py +343 -0
  9. backend/models/Seamless/vad_s2st_sc_24khz_main.yaml +25 -0
  10. backend/models/SeamlessStreaming/vad_s2st_sc_main.yaml +21 -0
  11. backend/mongodb/endpoints/__pycache__/calls.cpython-310.pyc +0 -0
  12. backend/mongodb/endpoints/__pycache__/users.cpython-310.pyc +0 -0
  13. backend/mongodb/endpoints/calls.py +96 -0
  14. backend/mongodb/endpoints/users.py +53 -0
  15. backend/mongodb/models/__pycache__/calls.cpython-310.pyc +0 -0
  16. backend/mongodb/models/__pycache__/users.cpython-310.pyc +0 -0
  17. backend/mongodb/models/calls.py +75 -0
  18. backend/mongodb/models/users.py +44 -0
  19. backend/mongodb/operations/__pycache__/calls.cpython-310.pyc +0 -0
  20. backend/mongodb/operations/__pycache__/users.cpython-310.pyc +0 -0
  21. backend/mongodb/operations/calls.py +285 -0
  22. backend/mongodb/operations/users.py +77 -0
  23. backend/pcmToWav.py +34 -0
  24. backend/preprocess_wav.py +65 -0
  25. backend/requirements.txt +28 -0
  26. backend/routes/__init__.py +1 -0
  27. backend/routes/__pycache__/__init__.cpython-310.pyc +0 -0
  28. backend/routes/__pycache__/routing.cpython-310.pyc +0 -0
  29. backend/routes/routing.py +9 -0
  30. backend/seamless/__init__.py +0 -0
  31. backend/seamless/room.py +64 -0
  32. backend/seamless/simuleval_agent_directory.py +171 -0
  33. backend/seamless/simuleval_transcoder.py +428 -0
  34. backend/seamless/speech_and_text_output.py +15 -0
  35. backend/seamless/transcoder_helpers.py +43 -0
  36. backend/seamless_utils.py +210 -0
  37. backend/tests/__pycache__/test_client.cpython-310-pytest-8.1.1.pyc +0 -0
  38. backend/tests/__pycache__/test_main.cpython-310-pytest-8.1.1.pyc +0 -0
  39. backend/tests/__pycache__/test_main.cpython-310.pyc +0 -0
  40. backend/tests/silence.wav +0 -0
  41. backend/tests/speaking.wav +0 -0
  42. backend/tests/test_client.py +59 -0
  43. backend/tests/test_main.py +90 -0
  44. backend/utils/__pycache__/text_rank.cpython-310.pyc +0 -0
  45. backend/utils/text_rank.py +60 -0
backend/.DS_Store ADDED
Binary file (6.15 kB). View file
 
backend/.env ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ MONGODB_URI=mongodb+srv://benjolo:26qtppddzz2jx9@it-cluster1.4cwyb2f.mongodb.net/?retryWrites=true&w=majority&appName=IT-Cluster1
2
+ OPENAI_API_KEY=sk-proj-vc4w7s6gkfwFG8xLBunZT3BlbkFJ8h9zOoyS0OY756vMgBcc
backend/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ myenv
2
+ .pytest_cache
backend/Client.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import wave
3
+ import os
4
+
5
+ import torchaudio
6
+ from vad import EnergyVAD
7
+ TARGET_SAMPLING_RATE = 16000
8
+
9
+ def create_frames(data: bytes, frame_duration: int) -> Tuple[bytes]:
10
+ frame_size = int(TARGET_SAMPLING_RATE * (frame_duration / 1000))
11
+ return (data[i:i + frame_size] for i in range(0, len(data), frame_size)), frame_size
12
+
13
+ def detect_activity(energies: list):
14
+ if sum(energies) < len(energies) / 12:
15
+ return False
16
+ count = 0
17
+ for energy in energies:
18
+ if energy == 1:
19
+ count += 1
20
+ if count == 12:
21
+ return True
22
+ else:
23
+ count = 0
24
+ return False
25
+
26
+ class Client:
27
+ def __init__(self, sid, client_id, username, call_id=None, original_sr=None):
28
+ self.sid = sid
29
+ self.client_id = client_id
30
+ self.username = username,
31
+ self.call_id = call_id
32
+ self.buffer = bytearray()
33
+ self.output_path = self.sid + "_output_audio.wav"
34
+ self.target_language = None
35
+ self.original_sr = original_sr
36
+ self.vad = EnergyVAD(
37
+ sample_rate=TARGET_SAMPLING_RATE,
38
+ frame_length=25,
39
+ frame_shift=20,
40
+ energy_threshold=0.05,
41
+ pre_emphasis=0.95,
42
+ ) # PM - Default values given in the docs for this class
43
+
44
+ def add_bytes(self, new_bytes):
45
+ self.buffer += new_bytes
46
+
47
+ def resample_and_clear(self):
48
+ print(f"📥 [ClientAudioBuffer] Writing {len(self.buffer)} bytes to {self.output_path}")
49
+ with wave.open(self.sid + "_OG.wav", "wb") as wf:
50
+ wf.setnchannels(1)
51
+ wf.setsampwidth(2)
52
+ wf.setframerate(self.original_sr)
53
+ wf.setnframes(0)
54
+ wf.setcomptype("NONE", "not compressed")
55
+ wf.writeframes(self.buffer)
56
+ waveform, sample_rate = torchaudio.load(self.sid + "_OG.wav")
57
+ resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLING_RATE, dtype=waveform.dtype)
58
+ resampled_waveform = resampler(waveform)
59
+ self.buffer = bytearray()
60
+ return resampled_waveform
61
+
62
+ def vad_analyse(self, resampled_waveform):
63
+ torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)
64
+ vad_array = self.vad(resampled_waveform)
65
+ print(f"VAD OUTPUT: {vad_array}")
66
+ return detect_activity(vad_array)
67
+
68
+ def write_to_file(self, resampled_waveform):
69
+ torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)
70
+
71
+ def get_length(self):
72
+ return len(self.buffer)
73
+
74
+ def __del__(self):
75
+ if len(self.buffer) > 0:
76
+ print(f"🚨 [ClientAudioBuffer] Buffer not empty for {self.sid} ({len(self.buffer)} bytes)!")
77
+ if os.path.exists(self.output_path):
78
+ os.remove(self.output_path)
79
+ if os.path.exists(self.sid + "_OG.wav"):
80
+ os.remove(self.sid + "_OG.wav")
81
+
backend/__pycache__/Client.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
backend/__pycache__/main.cpython-310.pyc ADDED
Binary file (6.48 kB). View file
 
backend/logging.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1
2
+ disable_existing_loggers: false
3
+
4
+ formatters:
5
+ standard:
6
+ format: "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s"
7
+
8
+ handlers:
9
+ console:
10
+ class: logging.StreamHandler
11
+ formatter: standard
12
+ stream: ext://sys.stdout
13
+
14
+ loggers:
15
+ uvicorn:
16
+ error:
17
+ propagate: true
18
+
19
+ root:
20
+ level: INFO
21
+ handlers: [console]
22
+ propagate: no
backend/main.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from operator import itemgetter
2
+ import os
3
+ from datetime import datetime
4
+ import uvicorn
5
+ from typing import Any, Optional, Tuple, Dict, TypedDict
6
+ from urllib import parse
7
+ from uuid import uuid4
8
+ import logging
9
+ from fastapi.logger import logger as fastapi_logger
10
+ import sys
11
+
12
+ from fastapi import FastAPI
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi import APIRouter, Body, Request, status
15
+ from pymongo import MongoClient
16
+ from dotenv import dotenv_values
17
+ from routes import router as api_router
18
+ from contextlib import asynccontextmanager
19
+ import requests
20
+
21
+ from typing import List
22
+ from datetime import date
23
+ from mongodb.operations.calls import *
24
+ from mongodb.operations.users import *
25
+ from mongodb.models.calls import UserCall, UpdateCall
26
+ # from mongodb.endpoints.calls import *
27
+
28
+ from transformers import AutoProcessor, SeamlessM4Tv2Model
29
+
30
+ # from seamless_communication.inference import Translator
31
+ from Client import Client
32
+ import numpy as np
33
+ import torch
34
+ import socketio
35
+
36
+ # Configure logger
37
+ gunicorn_error_logger = logging.getLogger("gunicorn.error")
38
+ gunicorn_logger = logging.getLogger("gunicorn")
39
+ uvicorn_access_logger = logging.getLogger("uvicorn.access")
40
+
41
+ gunicorn_error_logger.propagate = True
42
+ gunicorn_logger.propagate = True
43
+ uvicorn_access_logger.propagate = True
44
+
45
+ uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
46
+ fastapi_logger.handlers = gunicorn_error_logger.handlers
47
+
48
+ # sio is the main socket.io entrypoint
49
+ sio = socketio.AsyncServer(
50
+ async_mode="asgi",
51
+ cors_allowed_origins="*",
52
+ logger=gunicorn_logger,
53
+ engineio_logger=gunicorn_logger,
54
+ )
55
+ # sio.logger.setLevel(logging.DEBUG)
56
+ socketio_app = socketio.ASGIApp(sio)
57
+ # app.mount("/", socketio_app)
58
+
59
+ config = dotenv_values(".env")
60
+
61
+ # Read connection string from environment vars
62
+ # uri = os.environ['MONGODB_URI']
63
+
64
+ # Read connection string from .env file
65
+ uri = config['MONGODB_URI']
66
+
67
+
68
+ # MongoDB Connection Lifespan Events
69
+ @asynccontextmanager
70
+ async def lifespan(app: FastAPI):
71
+ # startup logic
72
+ app.mongodb_client = MongoClient(uri)
73
+ app.database = app.mongodb_client['IT-Cluster1'] #connect to interpretalk primary db
74
+ try:
75
+ app.mongodb_client.admin.command('ping')
76
+ print("MongoDB Connection Established...")
77
+ except Exception as e:
78
+ print(e)
79
+
80
+ yield
81
+
82
+ # shutdown logic
83
+ print("Closing MongoDB Connection...")
84
+ app.mongodb_client.close()
85
+
86
+ app = FastAPI(lifespan=lifespan, logger=gunicorn_logger)
87
+
88
+ # New CORS funcitonality
89
+ app.add_middleware(
90
+ CORSMiddleware,
91
+ allow_origins=["*"], # configured node app port
92
+ allow_credentials=True,
93
+ allow_methods=["*"],
94
+ allow_headers=["*"],
95
+ )
96
+
97
+ app.include_router(api_router) # include routers for user, calls and transcripts operations
98
+
99
+ DEBUG = True
100
+
101
+ ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"
102
+
103
+ TARGET_SAMPLING_RATE = 16000
104
+ MAX_BYTES_BUFFER = 960_000
105
+
106
+ print("")
107
+ print("")
108
+ print("=" * 18 + " Interpretalk is starting... " + "=" * 18)
109
+
110
+ ###############################################
111
+ # Configure socketio server
112
+ ###############################################
113
+
114
+ # TODO PM - change this to the actual path
115
+ # seamless remnant code
116
+ CLIENT_BUILD_PATH = "../streaming-react-app/dist/"
117
+ static_files = {
118
+ "/": CLIENT_BUILD_PATH,
119
+ "/assets/seamless-db6a2555.svg": {
120
+ "filename": CLIENT_BUILD_PATH + "assets/seamless-db6a2555.svg",
121
+ "content_type": "image/svg+xml",
122
+ },
123
+ }
124
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
125
+ processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
126
+
127
+ # PM - hardcoding temporarily as my GPU doesnt have enough vram
128
+ model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
129
+
130
+
131
+ bytes_data = bytearray()
132
+ model_name = "seamlessM4T_v2_large"
133
+ vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"
134
+
135
+ clients = {}
136
+ rooms = {}
137
+
138
+
139
+ def get_collection_users():
140
+ return app.database["user_records"]
141
+
142
+ def get_collection_calls():
143
+ return app.database["call_records"]
144
+
145
+
146
+ @app.get("/home/", response_description="Welcome User")
147
+ def test():
148
+ return {"message": "Welcome to InterpreTalk!"}
149
+
150
+
151
+ async def send_translated_text(client_id, username, original_text, translated_text, room_id):
152
+ # print(rooms) # Debugging
153
+ # print(clients) # Debugging
154
+
155
+ data = {
156
+ "author_id": str(client_id),
157
+ "author_username": str(username),
158
+ "original_text": str(original_text),
159
+ "translated_text": str(translated_text),
160
+ "timestamp": str(datetime.now())
161
+ }
162
+ gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
163
+ await sio.emit("translated_text", data, room=room_id)
164
+ gunicorn_logger.info("SUCCESSFULLY SEND AUDIO TO FRONTEND")
165
+
166
+
167
+ @sio.on("connect")
168
+ async def connect(sid, environ):
169
+ print(f"📥 [event: connected] sid={sid}")
170
+ query_params = dict(parse.parse_qsl(environ["QUERY_STRING"]))
171
+
172
+ client_id = query_params.get("client_id")
173
+ gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")
174
+
175
+ # get username to Client Object from DB
176
+ username = find_name_from_id(get_collection_users(), client_id)
177
+
178
+ # sid = socketid, client_id = client specific ID ,always the same for same user
179
+ clients[sid] = Client(sid, client_id, username)
180
+ print(clients[sid].username)
181
+ gunicorn_logger.warning(f"Client connected: {sid}")
182
+ gunicorn_logger.warning(clients)
183
+
184
+
185
+ @sio.on("disconnect")
186
+ async def disconnect(sid):
187
+ gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
188
+
189
+ call_id = clients[sid].call_id
190
+ user_id = clients[sid].client_id
191
+ target_language = clients[sid].target_language
192
+
193
+ clients.pop(sid, None)
194
+
195
+ # Perform Key Term Extraction and summarisation
196
+ try:
197
+ # Get combined caption field for call record based on call_id
198
+ key_terms = term_extraction(get_collection_calls(), call_id, user_id, target_language)
199
+
200
+ # Perform summarisation based on target language
201
+ summary_result = summarise(get_collection_calls(), call_id, user_id, target_language)
202
+
203
+ except:
204
+ gunicorn_logger.error(f"📤 [event: term_extraction/summarisation request error] sid={sid}, call={call_id}")
205
+
206
+
207
+ @sio.on("target_language")
208
+ async def target_language(sid, target_lang):
209
+ gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
210
+ clients[sid].target_language = target_lang
211
+
212
+
213
+ @sio.on("call_user")
214
+ async def call_user(sid, call_id):
215
+ clients[sid].call_id = call_id
216
+ gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")
217
+ rooms[call_id] = rooms.get(call_id, [])
218
+ if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
219
+ rooms[call_id].append(sid)
220
+ sio.enter_room(sid, call_id)
221
+ else:
222
+ gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
223
+ # await sio.emit("room_full", room=call_id, to=sid)
224
+
225
+ # BO - Get call id from dictionary created during socketio connection
226
+ client_id = clients[sid].client_id
227
+
228
+ gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
229
+ # BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..)
230
+ request_data = {
231
+ "call_id": str(call_id),
232
+ "caller_id": str(client_id),
233
+ "creation_date": str(datetime.now())
234
+ }
235
+
236
+ response = create_calls(get_collection_calls(), request_data)
237
+ print(response) # BO - print created db call record
238
+
239
+
240
+ @sio.on("audio_config")
241
+ async def audio_config(sid, sample_rate):
242
+ clients[sid].original_sr = sample_rate
243
+
244
+
245
+ @sio.on("answer_call")
246
+ async def answer_call(sid, call_id):
247
+
248
+ clients[sid].call_id = call_id
249
+ gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")
250
+ rooms[call_id] = rooms.get(call_id, [])
251
+ if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
252
+ rooms[call_id].append(sid)
253
+ sio.enter_room(sid, call_id)
254
+ else:
255
+ gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
256
+ # await sio.emit("room_full", room=call_id, to=sid)
257
+
258
+
259
+ # BO - Get call id from dictionary created during socketio connection
260
+ client_id = clients[sid].client_id
261
+
262
+ # BO -> Update Call Record with Callee field based on call_id
263
+ gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
264
+ # BO -> Create Call Record with callee_id field (None for callee, duration, terms..)
265
+ request_data = {
266
+ "callee_id": client_id
267
+ }
268
+
269
+ response = update_calls(get_collection_calls(), call_id, request_data)
270
+ print(response) # BO - print created db call record
271
+
272
+
273
+ @sio.on("incoming_audio")
274
+ async def incoming_audio(sid, data, call_id):
275
+ try:
276
+ clients[sid].add_bytes(data)
277
+
278
+ if clients[sid].get_length() >= MAX_BYTES_BUFFER:
279
+ gunicorn_logger.info('Buffer full, now outputting...')
280
+ output_path = clients[sid].output_path
281
+ resampled_audio = clients[sid].resample_and_clear()
282
+ vad_result = clients[sid].vad_analyse(resampled_audio)
283
+ # source lang is speakers tgt language 😃
284
+ src_lang = clients[sid].target_language
285
+
286
+ if vad_result:
287
+ gunicorn_logger.info('Speech detected, now processing audio.....')
288
+ tgt_sid = next(id for id in rooms[call_id] if id != sid)
289
+ tgt_lang = clients[tgt_sid].target_language
290
+ # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
291
+ output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt", sampling_rate=TARGET_SAMPLING_RATE).to(device)
292
+ model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
293
+ asr_text = processor.decode(model_output, skip_special_tokens=True)
294
+ print(f"ASR TEXT = {asr_text}")
295
+ # ASR TEXT => ORIGINAL TEXT
296
+
297
+ if src_lang != tgt_lang:
298
+ t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt").to(device)
299
+ translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
300
+ translated_text = processor.decode(translated_data, skip_special_tokens=True)
301
+ print(f"TRANSLATED TEXT = {translated_text}")
302
+ else:
303
+ # PM - both users have same language selected, no need to translate
304
+ translated_text = asr_text
305
+
306
+ # PM - text_output is a list with 1 string
307
+ await send_translated_text(clients[sid].client_id, clients[sid].username, asr_text, translated_text, call_id)
308
+
309
+ # BO -> send translated_text to mongodb as caption record update based on call_id
310
+ await send_captions(clients[sid].client_id, clients[sid].username, asr_text, translated_text, call_id)
311
+
312
+ except Exception as e:
313
+ gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}")
314
+
315
+
316
+ async def send_captions(client_id, username, original_text, translated_text, call_id):
317
+ # BO -> Update Call Record with Callee field based on call_id
318
+ print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
319
+
320
+ data = {
321
+ "author_id": str(client_id),
322
+ "author_username": str(username),
323
+ "original_text": str(original_text),
324
+ "translated_text": str(translated_text),
325
+ "timestamp": str(datetime.now())
326
+ }
327
+
328
+ response = update_captions(get_collection_calls(), get_collection_users(), call_id, data)
329
+ return response
330
+
331
+
332
+ app.mount("/", socketio_app)
333
+
334
+
335
+ if __name__ == '__main__':
336
+ uvicorn.run("main:app", host='0.0.0.0', port=7860, log_level="info")
337
+
338
+
339
+ # Running in Docker Container
340
+ if __name__ != "__main__":
341
+ fastapi_logger.setLevel(gunicorn_logger.level)
342
+ else:
343
+ fastapi_logger.setLevel(logging.DEBUG)
backend/models/Seamless/vad_s2st_sc_24khz_main.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ agent_class: seamless_communication.streaming.agents.seamless_s2st.SeamlessS2STDualVocoderVADAgent
2
+ monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
3
+ unity_model_name: seamless_streaming_unity
4
+ sentencepiece_model: spm_256k_nllb100.model
5
+
6
+ task: s2st
7
+ tgt_lang: "eng"
8
+ min_unit_chunk_size: 50
9
+ decision_threshold: 0.7
10
+ no_early_stop: True
11
+ block_ngrams: True
12
+ vocoder_name: vocoder_v2
13
+ expr_vocoder_name: vocoder_pretssel
14
+ gated_model_dir: .
15
+ expr_vocoder_gain: 3.0
16
+ upstream_idx: 1
17
+ wav2vec_yaml: wav2vec.yaml
18
+ min_starting_wait_w2vbert: 192
19
+
20
+ config_yaml: cfg_fbank_u2t.yaml
21
+ upstream_idx: 1
22
+ detokenize_only: True
23
+ device: cuda:0
24
+ max_len_a: 0
25
+ max_len_b: 1000
backend/models/SeamlessStreaming/vad_s2st_sc_main.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ agent_class: seamless_communication.streaming.agents.seamless_streaming_s2st.SeamlessStreamingS2STJointVADAgent
2
+ monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
3
+ unity_model_name: seamless_streaming_unity
4
+ sentencepiece_model: spm_256k_nllb100.model
5
+
6
+ task: s2st
7
+ tgt_lang: "eng"
8
+ min_unit_chunk_size: 50
9
+ decision_threshold: 0.7
10
+ no_early_stop: True
11
+ block_ngrams: True
12
+ vocoder_name: vocoder_v2
13
+ wav2vec_yaml: wav2vec.yaml
14
+ min_starting_wait_w2vbert: 192
15
+
16
+ config_yaml: cfg_fbank_u2t.yaml
17
+ upstream_idx: 1
18
+ detokenize_only: True
19
+ device: cuda:0
20
+ max_len_a: 0
21
+ max_len_b: 1000
backend/mongodb/endpoints/__pycache__/calls.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
backend/mongodb/endpoints/__pycache__/users.cpython-310.pyc ADDED
Binary file (2.43 kB). View file
 
backend/mongodb/endpoints/calls.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Body, Request, status, HTTPException
2
+ from typing import List
3
+ from datetime import date
4
+
5
+ import sys
6
+
7
+ from ..operations import calls as calls
8
+ from ..models.calls import UserCaptions, UserCall, UpdateCall
9
+ from ..endpoints.users import get_collection_users
10
+
11
+ router = APIRouter(prefix="/call",
12
+ tags=["Calls"])
13
+
14
+ def get_collection_calls(request: Request):
15
+ try:
16
+ return request.app.database["call_records"]
17
+ # return request.app.database["call_test"]
18
+ except:
19
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Unable to find call records Database.")
20
+
21
+ @router.post("/create-call", response_description="Create a new user call record", status_code=status.HTTP_201_CREATED, response_model=UserCall)
22
+ async def create_calls(request: Request, user_calls: UserCall = Body(...)):
23
+ collection = get_collection_calls(request)
24
+ return calls.create_calls(collection, user_calls)
25
+
26
+ @router.get("/list-call", response_description="List all existing call records", response_model=List[UserCall])
27
+ async def list_calls(request: Request, limit: int):
28
+ collection = get_collection_calls(request)
29
+ return calls.list_calls(collection, 100)
30
+
31
+ @router.get("/find-call/{call_id}", response_description="Find user's calls based on User ID", response_model=UserCall)
32
+ async def find_call(request: Request, call_id: str):
33
+ collection = get_collection_calls(request)
34
+ return calls.find_call(collection, call_id)
35
+
36
+ @router.get("/find-user-calls/{user_id}", response_description="Find user's calls based on User ID", response_model=List[UserCall])
37
+ async def find_user_calls(request: Request, user_id: str):
38
+ collection = get_collection_calls(request)
39
+ return calls.find_user_calls(collection, user_id)
40
+
41
+ @router.get("/get-captions/{user_id}", response_description="Find user's calls based on User ID")
42
+ async def get_caption_text(request: Request, call_id: str, user_id: str):
43
+ collection = get_collection_calls(request)
44
+ return calls.get_caption_text(collection, call_id, user_id)
45
+
46
+ '''Key terms list can have variable length -> using POST request over GET'''
47
+ @router.post("/find-term/", response_description="Find calls based on key term list", response_model=List[UserCall])
48
+ async def list_transcripts_by_key_terms(request: Request, key_terms: List[str]):
49
+ collection = get_collection_calls(request)
50
+ return calls.list_transcripts_by_key_terms(collection, key_terms)
51
+
52
+ @router.get("/find-date/{start_date}/{end_date}", response_description="Find calls based on date ranges", response_model=List[UserCall])
53
+ async def list_transcripts_by_dates(request: Request, start_date: str, end_date: str):
54
+ collection = get_collection_calls(request)
55
+ return calls.list_transcripts_by_dates(collection, start_date, end_date)
56
+
57
+ @router.get("/find-duration/{min_len}/{max_len}", response_description="Find calls based on call duration in minutes", response_model=List[UserCall])
58
+ async def list_transcripts_by_duration(request: Request, min_len: int, max_len: int):
59
+ collection = get_collection_calls(request)
60
+ return calls.list_transcripts_by_duration(collection, min_len, max_len)
61
+
62
+ @router.put("/update-call/{call_id}", response_description="Update an existing call", response_model=UpdateCall)
63
+ async def update_calls(request: Request, call_id: str, user_calls: UpdateCall = Body(...)):
64
+ collection = get_collection_calls(request)
65
+ return calls.update_calls(collection, call_id, user_calls)
66
+
67
+ @router.put("/update-captions/{call_id}", response_description="Update an existing call", response_model=UpdateCall)
68
+ async def update_captions(request: Request, call_id: str, user_calls: UserCaptions = Body(...)):
69
+ call_collection = get_collection_calls(request)
70
+ user_collection = get_collection_users(request)
71
+ return calls.update_captions(call_collection, user_collection, call_id, user_calls)
72
+
73
+ @router.delete("/delete-call/{call_id}", response_description="Delete a call by its id")
74
+ async def delete_call(request: Request, call_id: str):
75
+ collection = get_collection_calls(request)
76
+ return calls.delete_calls(collection, call_id)
77
+
78
+ @router.get("/full-text-search/{query}", response_description="Perform full text search on caption fields", response_model=List[UserCall])
79
+ async def full_text_search(request: Request, query: str):
80
+ collection = get_collection_calls(request)
81
+ return calls.full_text_search(collection, query)
82
+
83
+ @router.get("/fuzzy-search/{user_id}/{query}", response_description="Perform fuzzy text search on caption fields", response_model=List[UserCall])
84
+ async def fuzzy_search(request: Request, user_id: str, query: str):
85
+ collection = get_collection_calls(request)
86
+ return calls.fuzzy_search(collection, user_id, query)
87
+
88
+ @router.get("/summarise/{call_id}/{user_id}/{target_language}", response_description="Perform gpt-3.5 summarisation on call_id")
89
+ async def summarise(request: Request, call_id: str, user_id: str, target_language: str):
90
+ collection = get_collection_calls(request)
91
+ return calls.summarise(collection, call_id, user_id, target_language)
92
+
93
+ @router.get("/term-extraction/{call_id}/{user_id}/{target_language}", response_description="Perform key term extraction on call record")
94
+ async def term_extraction(request: Request, call_id: str, user_id: str, target_language: str):
95
+ collection = get_collection_calls(request)
96
+ return calls.term_extraction(collection, call_id, user_id, target_language)
backend/mongodb/endpoints/users.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Body, Request, status, HTTPException
2
+ from typing import List
3
+ import sys
4
+ from ..models.users import User, UpdateUser
5
+ from ..operations import users as users
6
+
7
+ router = APIRouter(prefix="/user",
8
+ tags=["User"])
9
+
10
+ def get_collection_users(request: Request):
11
+ db = request.app.database["user_records"]
12
+ return db
13
+
14
+ @router.post("/", response_description="Create a new user", status_code=status.HTTP_201_CREATED, response_model=User)
15
+ async def create_user(request: Request, user: User = Body(...)):
16
+ collection = get_collection_users(request)
17
+ return users.create_user(collection, user)
18
+
19
+ @router.get("/", response_description="List users", response_model=List[User])
20
+ async def list_users(request: Request):
21
+ collection = get_collection_users(request)
22
+ return users.list_users(collection, 100)
23
+
24
+ @router.put("/{user_id}", response_description="Update a User", response_model=UpdateUser)
25
+ async def update_user(request: Request, user_id: str, user: UpdateUser = Body(...)):
26
+ collection = get_collection_users(request)
27
+ return users.update_user(collection, user_id, user)
28
+
29
+ @router.get("/{user_id}", response_description="Get a single user by id", response_model=User)
30
+ async def find_user(request: Request, user_id: str):
31
+ collection = get_collection_users(request)
32
+ return users.find_user(collection, user_id)
33
+
34
+ @router.get("/find-name-id/{user_id}", response_description="Get a username from user id")
35
+ async def find_name_from_id(request: Request, user_id: str):
36
+ collection = get_collection_users(request)
37
+ return users.find_name_from_id(collection, user_id)
38
+
39
+ @router.get("/name/{user_name}", response_description="Get a single user by name", response_model=User)
40
+ async def find_user_name(request: Request, name: str):
41
+ collection = get_collection_users(request)
42
+ return users.find_user_name(collection, name)
43
+
44
+ @router.get("/email/{email_addr}", response_description="Get a single user by email", response_model=User)
45
+ async def find_user_email(request: Request, email: str):
46
+ collection = get_collection_users(request)
47
+ return users.find_user_email(collection, email)
48
+
49
+ @router.delete("/{user_id}", response_description="Delete a user")
50
+ async def delete_user(request: Request, user_id:str):
51
+ collection = get_collection_users(request)
52
+ return users.delete_user(collection, user_id)
53
+
backend/mongodb/models/__pycache__/calls.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
backend/mongodb/models/__pycache__/users.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
backend/mongodb/models/calls.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import List, Dict, Optional
3
+ from datetime import datetime
4
+ from pydantic import BaseModel, Field, PrivateAttr
5
+ import sys
6
+
7
+
8
+ ''' Class for storing captions generated by SeamlessM4T'''
9
+ class UserCaptions(BaseModel):
10
+ _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4) # private attr not included in http calls
11
+ author_id: Optional[str] = None
12
+ author_username: Optional[str] = None
13
+ original_text: str
14
+ translated_text: str
15
+ timestamp: datetime = Field(default_factory=datetime.now)
16
+
17
+ class Config:
18
+ populate_by_name = True
19
+ json_schema_extra = {
20
+ "example": {
21
+ "author_id": "gLZrfTwXyLUPB3eT7xT2HZnZiZT2",
22
+ "author_username": "shamzino",
23
+ "original_text": "eng: This is original_text english text",
24
+ "translated_text": "spa: este es el texto traducido al español",
25
+ "timestamp": "2024-03-28T16:15:50.956055",
26
+
27
+ }
28
+ }
29
+
30
+
31
+ '''Class for storing past call records from users'''
32
+ class UserCall(BaseModel):
33
+ _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
34
+ call_id: Optional[str] = None
35
+ caller_id: Optional[str] = None
36
+ callee_id: Optional[str] = None
37
+ creation_date: datetime = Field(default_factory=datetime.now, alias="date")
38
+ duration: Optional[int] = None # milliseconds
39
+ captions: Optional[List[UserCaptions]] = None
40
+ key_terms: Optional[dict] = None
41
+ summaries: Optional[dict] = None
42
+
43
+
44
+ class Config:
45
+ populate_by_name = True
46
+ json_schema_extra = {
47
+ "example": {
48
+ "call_id": "65eef930e9abd3b1e3506906",
49
+ "caller_id": "65ede65b6d246e52aaba9d4f",
50
+ "callee_id": "65edda944340ac84c1f00758",
51
+ "duration": 360,
52
+ "captions": [{"author_id": "gLZrfTwXyLUPB3eT7xT2HZnZiZT2", "author_username": "shamzino", "original_text": "eng: This is original_text english text", "translated_text": "spa: este es el texto traducido al español", "timestamp": "2024-03-28T16:15:50.956055"},
53
+ {"author_id": "g7pR1qCibzQf5mDP9dGtcoWeEc92", "author_username": "benjino", "original_text": "eng: This is source english text", "translated_text": "spa: este es el texto fuente al español", "timestamp": "2024-03-28T16:16:20.34625"}],
54
+ "key_terms": {"gLZrfTwXyLUPB3eT7xT2HZnZiZT2": ["original_text", "source", "english", "text"], "g7pR1qCibzQf5mDP9dGtcoWeEc92": ["translated_text", "destination", "spanish", "text"]},
55
+ "summaries": {"gLZrfTwXyLUPB3eT7xT2HZnZiZT2": "This is a short test on lanuguage translation", "65edda944340ac84c1f00758": "Esta es una breve prueba sobre traducción de idiomas."}
56
+ }
57
+ }
58
+
59
+
60
+ ''' Class for updating User Call record'''
61
+ class UpdateCall(BaseModel):
62
+ call_id: Optional[str] = None
63
+ caller_id: Optional[str] = None
64
+ callee_id: Optional[str] = None
65
+ duration: Optional[int] = None
66
+ captions: Optional[List[UserCaptions]] = None
67
+ key_terms: Optional[List[str]] = None
68
+
69
+ class Config:
70
+ populate_by_name = True
71
+ json_schema_extra = {
72
+ "example": {
73
+ "duration": "500"
74
+ }
75
+ }
backend/mongodb/models/users.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import List, Optional
3
+ from pydantic import BaseModel, Field, SecretStr, PrivateAttr
4
+ from pydantic.networks import EmailStr
5
+
6
+
7
+ '''Class for user model used to relate users to past calls'''
8
+ class User(BaseModel):
9
+ _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4) # private attr not included in http calls
10
+ user_id: str
11
+ name: str
12
+ email: EmailStr = Field(unique=True, index=True)
13
+ # password: SecretStr
14
+ call_ids: Optional[List[str]] = None
15
+
16
+ class Config:
17
+ populate_by_name = True
18
+ json_schema_extra = {
19
+ "example": {
20
+ "user_id": "65ede65b6d246e52aaba9d4f",
21
+ "name": "benjolo",
22
+ "email": "benjolounchained@gmail.com",
23
+ "call_ids": ["65e205ced1be3a22854ff300", "65df8c3eba9c7c2ed1b20e85"]
24
+ }
25
+ }
26
+
27
+ '''Class for updating user records'''
28
+ class UpdateUser(BaseModel):
29
+ user_id: Optional[str] = None
30
+ name: Optional[str] = None
31
+ email: Optional[EmailStr] = None
32
+ ''' To decode use -> SecretStr("abc").get_secret_value()'''
33
+ # password: Optional[SecretStr]
34
+ call_ids: Optional[List[str]] = None
35
+
36
+ class Config:
37
+ populate_by_name = True
38
+ json_schema_extra = {
39
+ "example": {
40
+ "email": "benjolounchained21@gmail.com",
41
+ "call_ids": ["65e205ced1be3a22854ff300", "65df8c3eba9c7c2ed1b20e85", "65eef930e9abd3b1e3506906"]
42
+ }
43
+ }
44
+
backend/mongodb/operations/__pycache__/calls.cpython-310.pyc ADDED
Binary file (6.61 kB). View file
 
backend/mongodb/operations/__pycache__/users.cpython-310.pyc ADDED
Binary file (2.93 kB). View file
 
backend/mongodb/operations/calls.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Body, Request, HTTPException, status
2
+ from fastapi.encoders import jsonable_encoder
3
+ import sys
4
+ from ..models.calls import UpdateCall, UserCall, UserCaptions
5
+ from ..operations.users import *
6
+ from utils.text_rank import extract_terms
7
+ from openai import OpenAI
8
+
9
+ from time import sleep
10
+ import os
11
+ from dotenv import dotenv_values
12
+
13
+ # Used within calls to create call record in main.py
14
+ def create_calls(collection, user: UserCall = Body(...)):
15
+ calls = jsonable_encoder(user)
16
+ new_calls = collection.insert_one(calls)
17
+ created_calls = collection.find_one({"_id": new_calls.inserted_id})
18
+
19
+ return created_calls
20
+
21
+ def list_calls(collection, limit: int):
22
+ try:
23
+ calls = collection.find(limit = limit)
24
+ return list(calls)
25
+ except:
26
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No existing call records yet.")
27
+
28
+
29
+ '''Finding calls based on call id'''
30
+ def find_call(collection, call_id: str):
31
+ user_calls = collection.find_one({"call_id": call_id})
32
+ if user_calls is not None:
33
+ return user_calls
34
+ else:
35
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with ID: '{call_id}' not found.")
36
+
37
+
38
+ '''Finding calls based on user id'''
39
+ def find_user_calls(collection, user_id: str):
40
+ user_calls = list(collection.find({"$or": [{"caller_id": user_id}, {"callee_id": user_id}]})) # match on caller or callee ID
41
+ if len(user_calls):
42
+ return user_calls
43
+ else:
44
+ return [] # return empty list if no existing calls for TranscriptView frontend component
45
+
46
+
47
+ '''Finding calls based on key terms list'''
48
+ def list_transcripts_by_key_terms(collection, key_terms_list: list[str] = Body(...)):
49
+ key_terms_list = jsonable_encoder(key_terms_list)
50
+
51
+ call_records = list(collection.find({"key_terms": {"$in": key_terms_list}}, {'_id': 0})) # exclude returning ObjectID in find()
52
+
53
+ # Check if any call records were returned
54
+ if len(call_records):
55
+ return call_records
56
+ else:
57
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with key terms: '{key_terms_list}' not found!")
58
+
59
+
60
+ '''Finding calls based on date ranges'''
61
+ def list_transcripts_by_dates(collection, start_date: str, end_date: str):
62
+ # print(start_date, end_date)
63
+
64
+ # Convert strings to date string in YYYY-MM-ddT00:00:00 format
65
+ start_date = f'{start_date}T00:00:00'
66
+ end_date = f'{end_date}T00:00:00'
67
+
68
+ call_records = list(collection.find({"date":{"$gte": start_date, "$lte": end_date}}))
69
+
70
+ if len(call_records):
71
+ return call_records
72
+ else:
73
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with creation date between: '{start_date} - {end_date}' not found!")
74
+
75
+
76
+ '''Finding calls based on call lengths'''
77
+ def list_transcripts_by_duration(collection, min_len: int, max_len: int):
78
+
79
+ call_records = list(collection.find({"duration":{"$gte": min_len, "$lte": max_len}}))
80
+
81
+ if len(call_records):
82
+ return call_records
83
+ else:
84
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with duration between: '{min_len} - {max_len}' milliseconds not found!")
85
+
86
+
87
+ def update_calls(collection, call_id: str, calls: UpdateCall = Body(...)):
88
+ # calls = {k: v for k, v in calls.model_dump().items() if v is not None} #loop in the dict
89
+ calls = {k: v for k, v in calls.items() if v is not None} #loop in the dict
90
+ print(calls)
91
+
92
+ if len(calls) >= 1:
93
+ update_result = collection.update_one({"call_id": call_id}, {"$set": calls})
94
+
95
+ if update_result.modified_count == 0:
96
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call not updated!")
97
+
98
+ if (existing_item := collection.find_one({"call_id": call_id})) is not None:
99
+ return existing_item
100
+
101
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call not found!")
102
+
103
+
104
+ def update_captions(call_collection, user_collection, call_id: str, captions: UserCaptions = Body(...)):
105
+ # captions = {k: v for k, v in calls.model_dump().items() if v is not None}
106
+ captions = {k: v for k, v in captions.items() if v is not None}
107
+ # print(captions)
108
+
109
+ # index user_id from caption object
110
+ userID = captions["author_id"]
111
+ # print(userID)
112
+
113
+ # use user id to get user name
114
+ username = find_name_from_id(user_collection, userID)
115
+ # print(username)
116
+
117
+ # add user name to captions json/object
118
+ captions["author_username"] = username
119
+ # print(captions)
120
+
121
+ if len(captions) >= 1:
122
+ update_result = call_collection.update_one({"call_id": call_id},
123
+ {"$push": {"captions": captions}})
124
+
125
+ if update_result.modified_count == 0:
126
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Captions not updated!")
127
+
128
+ if (existing_item := call_collection.find_one({"call_id": call_id})) is not None:
129
+ return existing_item
130
+
131
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Captions not found!")
132
+
133
+
134
+ def delete_calls(collection, call_id: str):
135
+ deleted_calls = collection.delete_one({"call_id": call_id})
136
+
137
+ if deleted_calls.deleted_count == 1:
138
+ return f"Call deleted sucessfully!"
139
+
140
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call not found!")
141
+
142
+
143
+ # def get_caption_text(collection, call_id):
144
+ # call_record = find_call((collection), call_id)
145
+
146
+ # try: # Check if call has any captions first
147
+ # caption_records = call_record['captions']
148
+ # except KeyError:
149
+ # return None
150
+
151
+ # # iterate through caption embedded document and store original text
152
+ # combined_text = [caption['original_text'] for caption in caption_records]
153
+
154
+ # return " ".join(combined_text)
155
+
156
+ def get_caption_text(collection, call_id, user_id):
157
+ call_record = find_call((collection), call_id)
158
+
159
+ try: # Check if call has any captions first
160
+ caption_records = call_record['captions']
161
+ except KeyError:
162
+ return None
163
+
164
+ # iterate through caption embedded document and store original text
165
+ # combined_text = [caption['original_text'] for caption in caption_records]
166
+
167
+ combined_text = []
168
+
169
+ for caption_segment in caption_records:
170
+ if caption_segment['author_id'] == user_id:
171
+ combined_text.append(caption_segment['original_text'])
172
+ else:
173
+ combined_text.append(caption_segment['translated_text'])
174
+
175
+ return " ".join(combined_text)
176
+
177
+
178
+ # standard exact match based full text search
179
+ def full_text_search(collection, query):
180
+
181
+ # drop any existing indexes and create new one
182
+ collection.drop_indexes()
183
+ collection.create_index([('captions.original_text', 'text'), ('captions.tranlated_text', 'text')],
184
+ name='captions')
185
+
186
+ # print(collection.index_information())
187
+
188
+ results = list(collection.find({"$text": {"$search": query}}))
189
+ return results
190
+
191
+ # approximate string matching
192
+ def fuzzy_search(collection, user_id, query):
193
+
194
+ user_calls = collection.find({"$or": [{"caller_id": user_id}, {"callee_id": user_id}]})
195
+ print("USER CALLS:", user_calls)
196
+
197
+ print("COLLECTION:", collection)
198
+
199
+ # drop any existing indexes and create new one
200
+ collection.drop_indexes()
201
+ collection.create_index([('captions.original_text', 'text'), ('captions.tranlated_text', 'text')],
202
+ name='captions')
203
+
204
+
205
+ pipeline = [
206
+ {
207
+ "$search": {
208
+ "text": {
209
+ "query": query,
210
+ "path": {"wildcard": "*"},
211
+ "fuzzy": {}
212
+ }
213
+ }
214
+ }
215
+ ]
216
+
217
+ collection_results = list(collection.aggregate(pipeline))
218
+
219
+ # add all users records to output
220
+ records = []
221
+
222
+ for doc in collection_results:
223
+ if doc['caller_id'] == user_id or doc['callee_id'] == user_id:
224
+ records.append(doc)
225
+
226
+ print(records)
227
+
228
+ return records
229
+
230
+
231
+ def summarise(collection, call_id, user_id, target_language):
232
+ # client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
233
+
234
+ config = dotenv_values(".env")
235
+ client = OpenAI(api_key=config["OPENAI_API_KEY"])
236
+
237
+ # get caption text using call_id
238
+ caption_text = get_caption_text(collection, call_id, user_id)
239
+
240
+ chat_completion = client.chat.completions.create(
241
+ messages=[
242
+ {
243
+ "role": "user",
244
+ "content": f"The following is an extract from a call transcript. Rewrite this as a structured, clear summary in {target_language}. \
245
+ \n\Call Transcript: \"\"\"\n{caption_text}\n\"\"\"\n"
246
+ }
247
+ ],
248
+ model="gpt-3.5-turbo",
249
+ )
250
+
251
+ # Gpt-3.5 turbo has 4096 token limit -> request will fail if exceeded
252
+ try:
253
+ result = chat_completion.choices[0].message.content.split(":")[1].strip() # parse summary
254
+ except:
255
+ return None
256
+
257
+ # BO - add result to mongodb -> should be done asynchronously
258
+ # summary_payload = {"summaries": {user_id: result}}
259
+
260
+ update_result = collection.update_one({"call_id": call_id}, {"$set": {f"summaries.{user_id}": result}})
261
+
262
+ if update_result.modified_count == 0:
263
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call not updated!")
264
+
265
+ # if (existing_item := collection.find_one({"call_id": call_id})) is not None:
266
+ # print(existing_item)
267
+
268
+ return result
269
+
270
+
271
+ def term_extraction(collection, call_id, user_id, target_language):
272
+
273
+ combined_text = get_caption_text(collection, call_id, user_id)
274
+
275
+ if len(combined_text) > 50: # > min_caption_length: -> poor term extraction on short transcripts
276
+
277
+ # Extract Key Terms from Concatenated Caption Field
278
+ key_terms = extract_terms(combined_text, target_language, len(combined_text))
279
+
280
+ update_result = collection.update_one({"call_id": call_id}, {"$set": {f"key_terms.{user_id}": key_terms}})
281
+
282
+ if update_result.modified_count == 0:
283
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call not updated!")
284
+
285
+ return key_terms
backend/mongodb/operations/users.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Body, Request, HTTPException, status
2
+ from fastapi.encoders import jsonable_encoder
3
+ import sys
4
+ from ..models.users import User, UpdateUser
5
+ from bson import ObjectId
6
+ import re
7
+
8
+
9
+ def create_user(collection, user: User = Body(...)):
10
+ user = jsonable_encoder(user)
11
+ new_user = collection.insert_one(user)
12
+ created_user = collection.find_one({"_id": new_user.inserted_id})
13
+ print("NEW ID IS:.........", new_user.inserted_id)
14
+ return created_user
15
+
16
+
17
+ def list_users(collection, limit: int):
18
+ try:
19
+ users = list(collection.find(limit = limit))
20
+ return users
21
+ except:
22
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No users found!")
23
+
24
+
25
+ def find_user(collection, user_id: str):
26
+ if (user := collection.find_one({"user_id": user_id})):
27
+ return user
28
+ else:
29
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")
30
+
31
+ def find_name_from_id(collection, user_id: str):
32
+
33
+ # find_one user record based on user id and project for user name
34
+ if (user_name := collection.find_one({"user_id": user_id}, {"name": 1, "_id": 0})):
35
+ return user_name['name'] # index name field from single field record returned
36
+ else:
37
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")
38
+
39
+ def find_user_name(collection, name: str):
40
+ # search for name in lowercase
41
+ if (user := collection.find_one({"name": re.compile('^' + re.escape(name) + '$', re.IGNORECASE)})):
42
+ return user
43
+ else:
44
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with name {name} not found!")
45
+
46
+
47
+ def find_user_email(collection, email: str):
48
+ if (user := collection.find_one({"email": re.compile('^' + re.escape(email) + '$', re.IGNORECASE)})):
49
+ return user
50
+ else:
51
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with Email Address {email} not found!")
52
+
53
+
54
+ ''' Update user record based on user object/json'''
55
+ def update_user(collection, user_id: str, user: UpdateUser):
56
+ try:
57
+ user = {k: v for k, v in user.model_dump().items() if v is not None}
58
+ if len(user) >= 1:
59
+ update_result = collection.update_one({"user_id": user_id}, {"$set": user})
60
+
61
+ if update_result.modified_count == 0:
62
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id: '{user_id}' not found and updated!")
63
+
64
+ if (existing_users := collection.find_one({"user_id": user_id})) is not None:
65
+ return existing_users
66
+ except:
67
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id: '{user_id}' not found and updated!")
68
+
69
+
70
+ def delete_user(collection, user_id: str):
71
+ try:
72
+ deleted_user = collection.delete_one({"user_id": user_id})
73
+
74
+ if deleted_user.deleted_count == 1:
75
+ return f"User with user_id {user_id} deleted sucessfully"
76
+ except:
77
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")
backend/pcmToWav.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wave
2
+ import os
3
+
4
+
5
+ basePath = os.path.expanduser("~/Desktop/")
6
+
7
+
8
+ def convert_pcm_to_wav():
9
+ # PCM file parameters (should match the parameters used to create the PCM file)
10
+ pcm_file = basePath + 'output.pcm'
11
+ wav_file = 'pcmconverted.wav'
12
+ sample_rate = 16000 # Example: 16000 Hz
13
+ channels = 1 # Example: 2 for stereo
14
+ sample_width = 2 # Example: 2 bytes (16 bits), change if your PCM format is different
15
+
16
+ # Read the PCM file and write to a WAV file
17
+ with open(pcm_file, 'rb') as pcmfile:
18
+ pcm_data = pcmfile.read()
19
+
20
+ with wave.open(wav_file, 'wb') as wavfile:
21
+ wavfile.setnchannels(channels)
22
+ wavfile.setsampwidth(sample_width)
23
+ wavfile.setframerate(sample_rate)
24
+ wavfile.writeframes(pcm_data)
25
+
26
+ convert_pcm_to_wav()
27
+
28
+ # def generateCaptions(filepath):
29
+
30
+ # ! This might be redundant due to seamless-streaming
31
+
32
+
33
+
34
+ print(f"Converted {pcm_file} to {wav_file}")
backend/preprocess_wav.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile
2
+ import io
3
+ from typing import Any, Tuple, Union, Optional
4
+ import numpy as np
5
+ import torch
6
+
7
+ def preprocess_wav(data: Any, incoming_sample_rate) -> Tuple[np.ndarray, int]:
8
+ segment, sample_rate = soundfile.read(
9
+ io.BytesIO(data),
10
+ dtype="float32",
11
+ always_2d=True,
12
+ frames=-1,
13
+ start=0,
14
+ format="RAW",
15
+ subtype="PCM_16",
16
+ samplerate=incoming_sample_rate,
17
+ channels=1,
18
+ )
19
+ return segment, sample_rate
20
+
21
+ def convert_waveform(
22
+ waveform: Union[np.ndarray, torch.Tensor],
23
+ sample_rate: int,
24
+ normalize_volume: bool = False,
25
+ to_mono: bool = False,
26
+ to_sample_rate: Optional[int] = None,
27
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
28
+ """convert a waveform:
29
+ - to a target sample rate
30
+ - from multi-channel to mono channel
31
+ - volume normalization
32
+
33
+ Args:
34
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
35
+ (channels x length)
36
+ sample_rate (int): original sample rate
37
+ normalize_volume (bool): perform volume normalization
38
+ to_mono (bool): convert to mono channel if having multiple channels
39
+ to_sample_rate (Optional[int]): target sample rate
40
+ Returns:
41
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
42
+ sample_rate (float): target sample rate
43
+ """
44
+ try:
45
+ import torchaudio.sox_effects as ta_sox
46
+ except ImportError:
47
+ raise ImportError("Please install torchaudio: pip install torchaudio")
48
+
49
+ effects = []
50
+ if normalize_volume:
51
+ effects.append(["gain", "-n"])
52
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
53
+ effects.append(["rate", f"{to_sample_rate}"])
54
+ if to_mono and waveform.shape[0] > 1:
55
+ effects.append(["channels", "1"])
56
+ if len(effects) > 0:
57
+ is_np_input = isinstance(waveform, np.ndarray)
58
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
59
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
60
+ _waveform, sample_rate, effects
61
+ )
62
+ if is_np_input:
63
+ converted = converted.numpy()
64
+ return converted, converted_sample_rate
65
+ return waveform, sample_rate
backend/requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ colorlog==6.8.2
2
+ contextlib2==21.6.0
3
+ fastapi==0.110.1
4
+ g2p_en==2.1.0
5
+ matplotlib==3.7.0
6
+ numpy==1.24.2
7
+ openai==1.20.0
8
+ protobuf==5.26.1
9
+ pydantic==2.7.0
10
+ pydub==0.25.1
11
+ pymongo==4.6.2
12
+ PySoundFile==0.9.0.post1
13
+ python-dotenv==1.0.1
14
+ python-socketio==5.9.0
15
+ pymongo==4.6.2
16
+ Requests==2.31.0
17
+ sentencepiece==0.1.99
18
+ simuleval==1.1.4
19
+ soundfile==0.12.1
20
+ spacy==3.7.4
21
+ pytextrank==3.3.0
22
+ torch==2.1.2
23
+ torchaudio==2.1.2
24
+ #transformers==4.20.1
25
+ uvicorn==0.29.0
26
+ vad==1.0.2
27
+ hf_transfer==0.1.4
28
+ huggingface_hub==0.19.4
backend/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from.routing import router
backend/routes/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (235 Bytes). View file
 
backend/routes/__pycache__/routing.cpython-310.pyc ADDED
Binary file (375 Bytes). View file
 
backend/routes/routing.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ import sys
3
+ # sys.path.append('/Users/benolojo/DCU/CA4/ca400_FinalYearProject/2024-ca400-olojob2-majdap2/src/backend/src/')
4
+ from mongodb.endpoints import users, calls
5
+
6
+ router = APIRouter()
7
+ router.include_router(calls.router)
8
+ router.include_router(users.router)
9
+ # router.include_router(transcripts.router)
backend/seamless/__init__.py ADDED
File without changes
backend/seamless/room.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import json
2
+ import uuid
3
+
4
+
5
+ class Room:
6
+ def __init__(self, room_id) -> None:
7
+ self.room_id = room_id
8
+ # members is a dict from client_id to Member
9
+ self.members = {}
10
+
11
+ # listeners and speakers are lists of client_id's
12
+ self.listeners = []
13
+ self.speakers = []
14
+
15
+ def __str__(self) -> str:
16
+ return f"Room {self.room_id} ({len(self.members)} member{'s' if len(self.members) == 1 else ''})"
17
+
18
+ def to_json(self):
19
+ varsResult = vars(self)
20
+ # Remember: result is just a shallow copy, so result.members === self.members
21
+ # Because of that, we need to jsonify self.members without writing over result.members,
22
+ # which we do here via dictionary unpacking (the ** operator)
23
+ result = {
24
+ **varsResult,
25
+ "members": {key: value.to_json() for (key, value) in self.members.items()},
26
+ "activeTranscoders": self.get_active_transcoders(),
27
+ }
28
+
29
+ return result
30
+
31
+ def get_active_connections(self):
32
+ return len(
33
+ [m for m in self.members.values() if m.connection_status == "connected"]
34
+ )
35
+
36
+ def get_active_transcoders(self):
37
+ return len([m for m in self.members.values() if m.transcoder is not None])
38
+
39
+ def get_room_status_dict(self):
40
+ return {
41
+ "activeConnections": self.get_active_connections(),
42
+ "activeTranscoders": self.get_active_transcoders(),
43
+ }
44
+
45
+
46
+ class Member:
47
+ def __init__(self, client_id, session_id, name) -> None:
48
+ self.client_id = client_id
49
+ self.session_id = session_id
50
+ self.name = name
51
+ self.connection_status = "connected"
52
+ self.transcoder = None
53
+ self.requested_output_type = None
54
+ self.transcoder_dynamic_config = None
55
+
56
+ def __str__(self) -> str:
57
+ return f"{self.name} (id: {self.client_id[:4]}...) ({self.connection_status})"
58
+
59
+ def to_json(self):
60
+ self_vars = vars(self)
61
+ return {
62
+ **self_vars,
63
+ "transcoder": self.transcoder is not None,
64
+ }
backend/seamless/simuleval_agent_directory.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Creates a directory in which to look up available agents
2
+
3
+ import os
4
+ from typing import List, Optional
5
+ from seamless.simuleval_transcoder import SimulevalTranscoder
6
+ import json
7
+ import logging
8
+
9
+ logger = logging.getLogger("gunicorn")
10
+
11
+ # fmt: off
12
+ M4T_P0_LANGS = [
13
+ "eng",
14
+ "arb", "ben", "cat", "ces", "cmn", "cym", "dan",
15
+ "deu", "est", "fin", "fra", "hin", "ind", "ita",
16
+ "jpn", "kor", "mlt", "nld", "pes", "pol", "por",
17
+ "ron", "rus", "slk", "spa", "swe", "swh", "tel",
18
+ "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie",
19
+ ]
20
+ # fmt: on
21
+
22
+
23
+ class NoAvailableAgentException(Exception):
24
+ pass
25
+
26
+
27
+ class AgentWithInfo:
28
+ def __init__(
29
+ self,
30
+ agent,
31
+ name: str,
32
+ modalities: List[str],
33
+ target_langs: List[str],
34
+ # Supported dynamic params are defined in StreamingTypes.ts
35
+ dynamic_params: List[str] = [],
36
+ description="",
37
+ has_expressive: Optional[bool] = None,
38
+ ):
39
+ self.agent = agent
40
+ self.has_expressive = has_expressive
41
+ self.name = name
42
+ self.description = description
43
+ self.modalities = modalities
44
+ self.target_langs = target_langs
45
+ self.dynamic_params = dynamic_params
46
+
47
+ def get_capabilities_for_json(self):
48
+ return {
49
+ "name": self.name,
50
+ "description": self.description,
51
+ "modalities": self.modalities,
52
+ "targetLangs": self.target_langs,
53
+ "dynamicParams": self.dynamic_params,
54
+ }
55
+
56
+ @classmethod
57
+ def load_from_json(cls, config: str):
58
+ """
59
+ Takes in JSON array of models to load in, e.g.
60
+ [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
61
+ {"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]
62
+ """
63
+ configs = json.loads(config)
64
+ agents = []
65
+ for config in configs:
66
+ agent = SimulevalTranscoder.build_agent(config["name"])
67
+ agents.append(
68
+ AgentWithInfo(
69
+ agent=agent,
70
+ name=config["name"],
71
+ modalities=config["modalities"],
72
+ target_langs=config["targetLangs"],
73
+ )
74
+ )
75
+ return agents
76
+
77
+
78
+ class SimulevalAgentDirectory:
79
+ # Available models. These are the directories where the models can be found, and also serve as an ID for the model.
80
+ seamless_streaming_agent = "SeamlessStreaming"
81
+ seamless_agent = "Seamless"
82
+
83
+ def __init__(self):
84
+ self.agents = []
85
+ self.did_build_and_add_agents = False
86
+
87
+ def add_agent(self, agent: AgentWithInfo):
88
+ self.agents.append(agent)
89
+
90
+ def build_agent_if_available(self, model_id, config_name=None):
91
+ agent = None
92
+ try:
93
+ if config_name is not None:
94
+ agent = SimulevalTranscoder.build_agent(
95
+ model_id,
96
+ config_name=config_name,
97
+ )
98
+ else:
99
+ agent = SimulevalTranscoder.build_agent(
100
+ model_id,
101
+ )
102
+ except Exception as e:
103
+ from fairseq2.assets.error import AssetError
104
+ logger.warning("Failed to build agent %s: %s" % (model_id, e))
105
+ if isinstance(e, AssetError):
106
+ logger.warning(
107
+ "Please download gated assets and set `gated_model_dir` in the config"
108
+ )
109
+ raise e
110
+
111
+ return agent
112
+
113
+ def build_and_add_agents(self, models_override=None):
114
+ if self.did_build_and_add_agents:
115
+ return
116
+
117
+ if models_override is not None:
118
+ agent_infos = AgentWithInfo.load_from_json(models_override)
119
+ for agent_info in agent_infos:
120
+ self.add_agent(agent_info)
121
+ else:
122
+ s2s_agent = None
123
+ if os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1":
124
+ logger.info("Building expressive model...")
125
+ s2s_agent = self.build_agent_if_available(
126
+ SimulevalAgentDirectory.seamless_agent,
127
+ config_name="vad_s2st_sc_24khz_main.yaml",
128
+ )
129
+ has_expressive = True
130
+ else:
131
+ logger.info("Building non-expressive model...")
132
+ s2s_agent = self.build_agent_if_available(
133
+ SimulevalAgentDirectory.seamless_streaming_agent,
134
+ config_name="vad_s2st_sc_main.yaml",
135
+ )
136
+ has_expressive = False
137
+
138
+ if s2s_agent:
139
+ self.add_agent(
140
+ AgentWithInfo(
141
+ agent=s2s_agent,
142
+ name=SimulevalAgentDirectory.seamless_streaming_agent,
143
+ modalities=["s2t", "s2s"],
144
+ target_langs=M4T_P0_LANGS,
145
+ dynamic_params=["expressive"],
146
+ description="multilingual expressive model that supports S2S and S2T",
147
+ has_expressive=has_expressive,
148
+ )
149
+ )
150
+
151
+ if len(self.agents) == 0:
152
+ logger.error(
153
+ "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
154
+ )
155
+
156
+ self.did_build_and_add_agents = True
157
+
158
+ def get_agent(self, name):
159
+ for agent in self.agents:
160
+ if agent.name == name:
161
+ return agent
162
+ return None
163
+
164
+ def get_agent_or_throw(self, name):
165
+ agent = self.get_agent(name)
166
+ if agent is None:
167
+ raise NoAvailableAgentException("No agent found with name= %s" % (name))
168
+ return agent
169
+
170
+ def get_agents_capabilities_list_for_json(self):
171
+ return [agent.get_capabilities_for_json() for agent in self.agents]
backend/seamless/simuleval_transcoder.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from simuleval.utils.agent import build_system_from_dir
2
+ from typing import Any, List, Optional, Tuple, Union
3
+ import numpy as np
4
+ import soundfile
5
+ import io
6
+ import asyncio
7
+ from simuleval.agents.pipeline import TreeAgentPipeline
8
+ from simuleval.agents.states import AgentStates
9
+ from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
10
+ import threading
11
+ import math
12
+ import logging
13
+ import sys
14
+ from pathlib import Path
15
+ import time
16
+ from g2p_en import G2p
17
+ import torch
18
+ import traceback
19
+ import time
20
+ import random
21
+ import colorlog
22
+
23
+ from .speech_and_text_output import SpeechAndTextOutput
24
+
25
+ MODEL_SAMPLE_RATE = 16_000
26
+
27
+ logger = logging.getLogger(__name__)
28
+ # logger.propagate = False
29
+ handler = colorlog.StreamHandler(stream=sys.stdout)
30
+ formatter = colorlog.ColoredFormatter(
31
+ "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s",
32
+ reset=True,
33
+ log_colors={
34
+ "DEBUG": "cyan",
35
+ "INFO": "green",
36
+ "WARNING": "yellow",
37
+ "ERROR": "red",
38
+ "CRITICAL": "red,bg_white",
39
+ },
40
+ )
41
+ handler.setFormatter(formatter)
42
+ logger.addHandler(handler)
43
+ logger.setLevel(logging.WARNING)
44
+
45
+
46
+ class OutputSegments:
47
+ def __init__(self, segments: Union[List[Segment], Segment]):
48
+ if isinstance(segments, Segment):
49
+ segments = [segments]
50
+ self.segments: List[Segment] = [s for s in segments]
51
+
52
+ @property
53
+ def is_empty(self):
54
+ return all(segment.is_empty for segment in self.segments)
55
+
56
+ @property
57
+ def finished(self):
58
+ return all(segment.finished for segment in self.segments)
59
+
60
+ def compute_length(self, g2p):
61
+ lengths = []
62
+ for segment in self.segments:
63
+ if segment.data_type == "text":
64
+ lengths.append(len([x for x in g2p(segment.content) if x != " "]))
65
+ elif segment.data_type == "speech":
66
+ lengths.append(len(segment.content) / MODEL_SAMPLE_RATE)
67
+ elif isinstance(segment, EmptySegment):
68
+ continue
69
+ else:
70
+ logger.warning(
71
+ f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'"
72
+ )
73
+ return max(lengths)
74
+
75
+ @classmethod
76
+ def join_output_buffer(
77
+ cls, buffer: List[List[Segment]], output: SpeechAndTextOutput
78
+ ):
79
+ num_segments = len(buffer[0])
80
+ for i in range(num_segments):
81
+ segment_list = [
82
+ buffer[j][i]
83
+ for j in range(len(buffer))
84
+ if buffer[j][i].data_type is not None
85
+ ]
86
+ if len(segment_list) == 0:
87
+ continue
88
+ if len(set(segment.data_type for segment in segment_list)) != 1:
89
+ logger.warning(
90
+ f"Data type mismatch at {i}: {set(segment.data_type for segment in segment_list)}"
91
+ )
92
+ continue
93
+ data_type = segment_list[0].data_type
94
+ if data_type == "text":
95
+ if output.text is not None:
96
+ logger.warning("Multiple text outputs, overwriting!")
97
+ output.text = " ".join([segment.content for segment in segment_list])
98
+ elif data_type == "speech":
99
+ if output.speech_samples is not None:
100
+ logger.warning("Multiple speech outputs, overwriting!")
101
+ speech_out = []
102
+ for segment in segment_list:
103
+ speech_out += segment.content
104
+ output.speech_samples = speech_out
105
+ output.speech_sample_rate = segment.sample_rate
106
+ elif isinstance(segment_list[0], EmptySegment):
107
+ continue
108
+ else:
109
+ logger.warning(
110
+ f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text"
111
+ )
112
+
113
+ return output
114
+
115
+ def __repr__(self) -> str:
116
+ repr_str = str(self.segments)
117
+ return f"{self.__class__.__name__}(\n\t{repr_str}\n)"
118
+
119
+
120
+ class SimulevalTranscoder:
121
+ def __init__(self, agent, sample_rate, debug, buffer_limit):
122
+ self.agent = agent.agent
123
+ self.has_expressive = agent.has_expressive
124
+ self.input_queue = asyncio.Queue()
125
+ self.output_queue = asyncio.Queue()
126
+ self.states = self.agent.build_states()
127
+ if debug:
128
+ self.get_states_root().debug = True
129
+ self.incoming_sample_rate = sample_rate
130
+ self.close = False
131
+ self.g2p = G2p()
132
+
133
+ # buffer all outgoing translations within this amount of time
134
+ self.output_buffer_idle_ms = 5000
135
+ self.output_buffer_size_limit = (
136
+ buffer_limit # phonemes for text, seconds for speech
137
+ )
138
+ self.output_buffer_cur_size = 0
139
+ self.output_buffer: List[List[Segment]] = []
140
+ self.speech_output_sample_rate = None
141
+
142
+ self.last_output_ts = time.time() * 1000
143
+ self.timeout_ms = (
144
+ 30000 # close the transcoder thread after this amount of silence
145
+ )
146
+ self.first_input_ts = None
147
+ self.first_output_ts = None
148
+ self.debug = debug
149
+ self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
150
+ if self.debug:
151
+ debug_folder = Path(__file__).resolve().parent.parent / "debug"
152
+ self.test_incoming_wav = soundfile.SoundFile(
153
+ debug_folder / f"{self.debug_ts}_test_incoming.wav",
154
+ mode="w+",
155
+ format="WAV",
156
+ subtype="PCM_16",
157
+ samplerate=self.incoming_sample_rate,
158
+ channels=1,
159
+ )
160
+ self.get_states_root().test_input_segments_wav = soundfile.SoundFile(
161
+ debug_folder / f"{self.debug_ts}_test_input_segments.wav",
162
+ mode="w+",
163
+ format="WAV",
164
+ samplerate=MODEL_SAMPLE_RATE,
165
+ channels=1,
166
+ )
167
+
168
+ def get_states_root(self) -> AgentStates:
169
+ if isinstance(self.agent, TreeAgentPipeline):
170
+ # self.states is a dict
171
+ return self.states[self.agent.source_module]
172
+ else:
173
+ # self.states is a list
174
+ return self.states[0]
175
+
176
+ def reset_states(self):
177
+ if isinstance(self.agent, TreeAgentPipeline):
178
+ states_iter = self.states.values()
179
+ else:
180
+ states_iter = self.states
181
+ for state in states_iter:
182
+ state.reset()
183
+
184
+ def debug_log(self, *args):
185
+ if self.debug:
186
+ logger.info(*args)
187
+
188
+ @classmethod
189
+ def build_agent(cls, model_path, config_name):
190
+ logger.info(f"Building simuleval agent: {model_path}, {config_name}")
191
+ agent = build_system_from_dir(
192
+ Path(__file__).resolve().parent.parent / f"models/{model_path}",
193
+ config_name=config_name,
194
+ )
195
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196
+ agent.to(device, fp16=True)
197
+ logger.info(
198
+ f"Successfully built simuleval agent {model_path} on device {device}"
199
+ )
200
+
201
+ return agent
202
+
203
+ def process_incoming_bytes(self, incoming_bytes, dynamic_config):
204
+ # TODO: We probably want to do some validation on dynamic_config to ensure it has what we needs
205
+ segment, sr = self._preprocess_wav(incoming_bytes)
206
+ segment = SpeechSegment(
207
+ content=segment,
208
+ sample_rate=sr,
209
+ tgt_lang=dynamic_config.get("targetLanguage"),
210
+ config=dynamic_config,
211
+ )
212
+ if dynamic_config.get("expressive") is True and self.has_expressive is False:
213
+ logger.warning(
214
+ "Passing 'expressive' but the agent does not support expressive output!"
215
+ )
216
+ # # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
217
+ self.input_queue.put_nowait(segment)
218
+
219
+ def get_input_segment(self):
220
+ if self.input_queue.empty():
221
+ return None
222
+ chunk = self.input_queue.get_nowait()
223
+ self.input_queue.task_done()
224
+ return chunk
225
+
226
+ def convert_waveform(
227
+ self,
228
+ waveform: Union[np.ndarray, torch.Tensor],
229
+ sample_rate: int,
230
+ normalize_volume: bool = False,
231
+ to_mono: bool = False,
232
+ to_sample_rate: Optional[int] = None,
233
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
234
+ """convert a waveform:
235
+ - to a target sample rate
236
+ - from multi-channel to mono channel
237
+ - volume normalization
238
+
239
+ Args:
240
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
241
+ (channels x length)
242
+ sample_rate (int): original sample rate
243
+ normalize_volume (bool): perform volume normalization
244
+ to_mono (bool): convert to mono channel if having multiple channels
245
+ to_sample_rate (Optional[int]): target sample rate
246
+ Returns:
247
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
248
+ sample_rate (float): target sample rate
249
+ """
250
+ try:
251
+ import torchaudio.sox_effects as ta_sox
252
+ except ImportError:
253
+ raise ImportError("Please install torchaudio: pip install torchaudio")
254
+
255
+ effects = []
256
+ if normalize_volume:
257
+ effects.append(["gain", "-n"])
258
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
259
+ effects.append(["rate", f"{to_sample_rate}"])
260
+ if to_mono and waveform.shape[0] > 1:
261
+ effects.append(["channels", "1"])
262
+ if len(effects) > 0:
263
+ is_np_input = isinstance(waveform, np.ndarray)
264
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
265
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
266
+ _waveform, sample_rate, effects
267
+ )
268
+ if is_np_input:
269
+ converted = converted.numpy()
270
+ return converted, converted_sample_rate
271
+ return waveform, sample_rate
272
+
273
+ def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
274
+ segment, sample_rate = soundfile.read(
275
+ io.BytesIO(data),
276
+ dtype="float32",
277
+ always_2d=True,
278
+ frames=-1,
279
+ start=0,
280
+ format="RAW",
281
+ subtype="PCM_16",
282
+ samplerate=self.incoming_sample_rate,
283
+ channels=1,
284
+ )
285
+ if self.debug:
286
+ self.test_incoming_wav.seek(0, soundfile.SEEK_END)
287
+ self.test_incoming_wav.write(segment)
288
+
289
+ segment = segment.T
290
+ segment, new_sample_rate = self.convert_waveform(
291
+ segment,
292
+ sample_rate,
293
+ normalize_volume=False,
294
+ to_mono=True,
295
+ to_sample_rate=MODEL_SAMPLE_RATE,
296
+ )
297
+
298
+ assert MODEL_SAMPLE_RATE == new_sample_rate
299
+ segment = segment.squeeze(axis=0)
300
+ return segment, new_sample_rate
301
+
302
+ def process_pipeline_impl(self, input_segment):
303
+ try:
304
+ with torch.no_grad():
305
+ output_segment = OutputSegments(
306
+ self.agent.pushpop(input_segment, self.states)
307
+ )
308
+ if (
309
+ self.get_states_root().first_input_ts is not None
310
+ and self.first_input_ts is None
311
+ ):
312
+ # TODO: this is hacky
313
+ self.first_input_ts = self.get_states_root().first_input_ts
314
+
315
+ if not output_segment.is_empty:
316
+ self.output_queue.put_nowait(output_segment)
317
+
318
+ if output_segment.finished:
319
+ self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
320
+
321
+ self.reset_states()
322
+
323
+ if self.debug:
324
+ # when we rebuild states, this value is reset to whatever
325
+ # is in the system dir config, which defaults debug=False.
326
+ self.get_states_root().debug = True
327
+ except Exception as e:
328
+ logger.error(f"Got exception while processing pipeline: {e}")
329
+ traceback.print_exc()
330
+ return input_segment
331
+
332
+ def process_pipeline_loop(self):
333
+ if self.close:
334
+ return # closes the thread
335
+
336
+ self.debug_log("processing_pipeline")
337
+ while not self.close:
338
+ input_segment = self.get_input_segment()
339
+ if input_segment is None:
340
+ if self.get_states_root().is_fresh_state: # TODO: this is hacky
341
+ time.sleep(0.3)
342
+ else:
343
+ time.sleep(0.03)
344
+ continue
345
+ self.process_pipeline_impl(input_segment)
346
+ self.debug_log("finished processing_pipeline")
347
+
348
+ def process_pipeline_once(self):
349
+ if self.close:
350
+ return
351
+
352
+ self.debug_log("processing pipeline once")
353
+ input_segment = self.get_input_segment()
354
+ if input_segment is None:
355
+ return
356
+ self.process_pipeline_impl(input_segment)
357
+ self.debug_log("finished processing_pipeline_once")
358
+
359
+ def get_output_segment(self):
360
+ if self.output_queue.empty():
361
+ return None
362
+
363
+ output_chunk = self.output_queue.get_nowait()
364
+ self.output_queue.task_done()
365
+ return output_chunk
366
+
367
+ def start(self):
368
+ self.debug_log("starting transcoder in a thread")
369
+ threading.Thread(target=self.process_pipeline_loop).start()
370
+
371
+ def first_translation_time(self):
372
+ return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
373
+
374
+ def get_buffered_output(self) -> SpeechAndTextOutput:
375
+ now = time.time() * 1000
376
+ self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
377
+ while not self.output_queue.empty():
378
+ tmp_out = self.get_output_segment()
379
+ if tmp_out and tmp_out.compute_length(self.g2p) > 0:
380
+ if len(self.output_buffer) == 0:
381
+ self.last_output_ts = now
382
+ self._populate_output_buffer(tmp_out)
383
+ self._increment_output_buffer_size(tmp_out)
384
+
385
+ if tmp_out.finished:
386
+ self.debug_log("tmp_out.finished")
387
+ res = self._gather_output_buffer_data(final=True)
388
+ self.debug_log(f"gathered output data: {res}")
389
+ self.output_buffer = []
390
+ self.increment_output_buffer_size = 0
391
+ self.last_output_ts = now
392
+ self.first_output_ts = now
393
+ return res
394
+ else:
395
+ self.debug_log("tmp_out.compute_length is not > 0")
396
+
397
+ if len(self.output_buffer) > 0 and (
398
+ now - self.last_output_ts >= self.output_buffer_idle_ms
399
+ or self.output_buffer_cur_size >= self.output_buffer_size_limit
400
+ ):
401
+ self.debug_log(
402
+ "[get_buffered_output] output_buffer is not empty. getting res to return."
403
+ )
404
+ self.last_output_ts = now
405
+ res = self._gather_output_buffer_data(final=False)
406
+ self.debug_log(f"gathered output data: {res}")
407
+ self.output_buffer = []
408
+ self.output_buffer_phoneme_count = 0
409
+ self.first_output_ts = now
410
+ return res
411
+ else:
412
+ self.debug_log("[get_buffered_output] output_buffer is empty...")
413
+ return None
414
+
415
+ def _gather_output_buffer_data(self, final):
416
+ output = SpeechAndTextOutput()
417
+ output.final = final
418
+ output = OutputSegments.join_output_buffer(self.output_buffer, output)
419
+ return output
420
+
421
+ def _increment_output_buffer_size(self, segment: OutputSegments):
422
+ self.output_buffer_cur_size += segment.compute_length(self.g2p)
423
+
424
+ def _populate_output_buffer(self, segment: OutputSegments):
425
+ self.output_buffer.append(segment.segments)
426
+
427
+ def _compute_phoneme_count(self, string: str) -> int:
428
+ return len([x for x in self.g2p(string) if x != " "])
backend/seamless/speech_and_text_output.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Provides a container to return both speech and text output from our model at the same time
2
+
3
+
4
+ class SpeechAndTextOutput:
5
+ def __init__(
6
+ self,
7
+ text: str = None,
8
+ speech_samples: list = None,
9
+ speech_sample_rate: float = None,
10
+ final: bool = False,
11
+ ):
12
+ self.text = text
13
+ self.speech_samples = speech_samples
14
+ self.speech_sample_rate = speech_sample_rate
15
+ self.final = final
backend/seamless/transcoder_helpers.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ logger = logging.getLogger("gunicorn")
4
+
5
+
6
+ def get_transcoder_output_events(transcoder) -> list:
7
+ speech_and_text_output = transcoder.get_buffered_output()
8
+ if speech_and_text_output is None:
9
+ logger.debug("No output from transcoder.get_buffered_output()")
10
+ return []
11
+
12
+ logger.debug(f"We DID get output from the transcoder! {speech_and_text_output}")
13
+
14
+ lat = None
15
+
16
+ events = []
17
+
18
+ if speech_and_text_output.speech_samples:
19
+ events.append(
20
+ {
21
+ "event": "translation_speech",
22
+ "payload": speech_and_text_output.speech_samples,
23
+ "sample_rate": speech_and_text_output.speech_sample_rate,
24
+ }
25
+ )
26
+
27
+ if speech_and_text_output.text:
28
+ events.append(
29
+ {
30
+ "event": "translation_text",
31
+ "payload": speech_and_text_output.text,
32
+ }
33
+ )
34
+
35
+ for e in events:
36
+ e["eos"] = speech_and_text_output.final
37
+
38
+ # if not latency_sent:
39
+ # lat = transcoder.first_translation_time()
40
+ # latency_sent = True
41
+ # to_send["latency"] = lat
42
+
43
+ return events
backend/seamless_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # base seamless imports
3
+ # ---------------------------------
4
+ import io
5
+ import json
6
+ import matplotlib as mpl
7
+ import matplotlib.pyplot as plt
8
+ import mmap
9
+ import numpy as np
10
+ import soundfile
11
+ import torchaudio
12
+ import torch
13
+ from pydub import AudioSegment
14
+ # ---------------------------------
15
+ # seamless-streaming specific imports
16
+ # ---------------------------------
17
+ import math
18
+ from simuleval.data.segments import SpeechSegment, EmptySegment
19
+ from seamless_communication.streaming.agents.seamless_streaming_s2st import (
20
+ SeamlessStreamingS2STVADAgent,
21
+ )
22
+
23
+ from simuleval.utils.arguments import cli_argument_list
24
+ from simuleval import options
25
+
26
+
27
+ from typing import Union, List
28
+ from simuleval.data.segments import Segment, TextSegment
29
+ from simuleval.agents.pipeline import TreeAgentPipeline
30
+ from simuleval.agents.states import AgentStates
31
+ # ---------------------------------
32
+ # seamless setup
33
+ # source: https://colab.research.google.com/github/kauterry/seamless_communication/blob/main/Seamless_Tutorial.ipynb?
34
+ SAMPLE_RATE = 16000
35
+
36
+ # PM - THis class is used to simulate the audio frontend in the seamless streaming pipeline
37
+ # need to replace this with the actual audio frontend
38
+ # TODO: replacement class that takes in PCM-16 bytes and returns SpeechSegment
39
+ class AudioFrontEnd:
40
+ def __init__(self, wav_file, segment_size) -> None:
41
+ self.samples, self.sample_rate = soundfile.read(wav_file)
42
+ print(self.sample_rate, "sample rate")
43
+ assert self.sample_rate == SAMPLE_RATE
44
+ # print(len(self.samples), self.samples[:100])
45
+ self.samples = self.samples # .tolist()
46
+ self.segment_size = segment_size
47
+ self.step = 0
48
+
49
+ def send_segment(self):
50
+ """
51
+ This is the front-end logic in simuleval instance.py
52
+ """
53
+
54
+ num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
55
+
56
+ if self.step < len(self.samples):
57
+ if self.step + num_samples >= len(self.samples):
58
+ samples = self.samples[self.step :]
59
+ is_finished = True
60
+ else:
61
+ samples = self.samples[self.step : self.step + num_samples]
62
+ is_finished = False
63
+ self.samples = self.samples[self.step:]
64
+ self.step = min(self.step + num_samples, len(self.samples))
65
+ segment = SpeechSegment(
66
+ content=samples,
67
+ sample_rate=self.sample_rate,
68
+ finished=is_finished,
69
+ )
70
+ else:
71
+ # Finish reading this audio
72
+ segment = EmptySegment(
73
+ finished=True,
74
+ )
75
+ self.step = 0
76
+ self.samples = []
77
+ return segment
78
+
79
+ # samples = self.samples[:num_samples]
80
+ # self.samples = self.samples[num_samples:]
81
+ # segment = SpeechSegment(
82
+ # content=samples,
83
+ # sample_rate=self.sample_rate,
84
+ # finished=False,
85
+ # )
86
+
87
+
88
+ def add_segments(self, wav):
89
+ new_samples, _ = soundfile.read(wav)
90
+ self.samples = np.concatenate((self.samples, new_samples))
91
+
92
+
93
+ class OutputSegments:
94
+ def __init__(self, segments: Union[List[Segment], Segment]):
95
+ if isinstance(segments, Segment):
96
+ segments = [segments]
97
+ self.segments: List[Segment] = [s for s in segments]
98
+
99
+ @property
100
+ def is_empty(self):
101
+ return all(segment.is_empty for segment in self.segments)
102
+
103
+ @property
104
+ def finished(self):
105
+ return all(segment.finished for segment in self.segments)
106
+
107
+
108
+ def get_audiosegment(samples, sr):
109
+ b = io.BytesIO()
110
+ soundfile.write(b, samples, samplerate=sr, format="wav")
111
+ b.seek(0)
112
+ return AudioSegment.from_file(b)
113
+
114
+
115
+ def reset_states(system, states):
116
+ if isinstance(system, TreeAgentPipeline):
117
+ states_iter = states.values()
118
+ else:
119
+ states_iter = states
120
+ for state in states_iter:
121
+ state.reset()
122
+
123
+
124
+ def get_states_root(system, states) -> AgentStates:
125
+ if isinstance(system, TreeAgentPipeline):
126
+ # self.states is a dict
127
+ return states[system.source_module]
128
+ else:
129
+ # self.states is a list
130
+ return system.states[0]
131
+
132
+
133
+ def build_streaming_system(model_configs, agent_class):
134
+ parser = options.general_parser()
135
+ parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
136
+
137
+ agent_class.add_args(parser)
138
+ args, _ = parser.parse_known_args(cli_argument_list(model_configs))
139
+ system = agent_class.from_args(args)
140
+ return system
141
+
142
+
143
+ def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
144
+ # NOTE: Here for visualization, we calculate delays offset from audio
145
+ # *BEFORE* VAD segmentation.
146
+ # In contrast for SimulEval evaluation, we assume audios are pre-segmented,
147
+ # and Average Lagging, End Offset metrics are based on those pre-segmented audios.
148
+ # Thus, delays here are *NOT* comparable to SimulEval per-segment delays
149
+ delays = {"s2st": [], "s2tt": []}
150
+ prediction_lists = {"s2st": [], "s2tt": []}
151
+ speech_durations = []
152
+ curr_delay = 0
153
+ target_sample_rate = None
154
+
155
+ while True:
156
+ input_segment = audio_frontend.send_segment()
157
+ input_segment.tgt_lang = tgt_lang
158
+ curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
159
+ if input_segment.finished:
160
+ # a hack, we expect a real stream to end with silence
161
+ get_states_root(system, system_states).source_finished = True
162
+ # Translation happens here
163
+ if isinstance(input_segment, EmptySegment):
164
+ return None, None, None, None
165
+ output_segments = OutputSegments(system.pushpop(input_segment, system_states))
166
+ if not output_segments.is_empty:
167
+ for segment in output_segments.segments:
168
+ # NOTE: another difference from SimulEval evaluation -
169
+ # delays are accumulated per-token
170
+ if isinstance(segment, SpeechSegment):
171
+ pred_duration = 1000 * len(segment.content) / segment.sample_rate
172
+ speech_durations.append(pred_duration)
173
+ delays["s2st"].append(curr_delay)
174
+ prediction_lists["s2st"].append(segment.content)
175
+ target_sample_rate = segment.sample_rate
176
+ elif isinstance(segment, TextSegment):
177
+ delays["s2tt"].append(curr_delay)
178
+ prediction_lists["s2tt"].append(segment.content)
179
+ print(curr_delay, segment.content)
180
+ if output_segments.finished:
181
+ reset_states(system, system_states)
182
+ if input_segment.finished:
183
+ # an assumption of SimulEval agents -
184
+ # once source_finished=True, generate until output translation is finished
185
+ break
186
+ return delays, prediction_lists, speech_durations, target_sample_rate
187
+
188
+
189
+ def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
190
+ # get calculate intervals + durations for s2st
191
+ intervals = []
192
+
193
+ start = prev_end = prediction_offset = delays["s2st"][0]
194
+ target_samples = [0.0] * int(target_sample_rate * prediction_offset / 1000)
195
+
196
+ for i, delay in enumerate(delays["s2st"]):
197
+ start = max(prev_end, delay)
198
+
199
+ if start > prev_end:
200
+ # Wait source speech, add discontinuity with silence
201
+ target_samples += [0.0] * int(
202
+ target_sample_rate * (start - prev_end) / 1000
203
+ )
204
+
205
+ target_samples += prediction_lists["s2st"][i]
206
+ duration = speech_durations[i]
207
+ prev_end = start + duration
208
+ intervals.append([start, duration])
209
+ return target_samples, intervals
210
+
backend/tests/__pycache__/test_client.cpython-310-pytest-8.1.1.pyc ADDED
Binary file (6.82 kB). View file
 
backend/tests/__pycache__/test_main.cpython-310-pytest-8.1.1.pyc ADDED
Binary file (3.38 kB). View file
 
backend/tests/__pycache__/test_main.cpython-310.pyc ADDED
Binary file (2.2 kB). View file
 
backend/tests/silence.wav ADDED
Binary file (302 kB). View file
 
backend/tests/speaking.wav ADDED
Binary file (255 kB). View file
 
backend/tests/test_client.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wave
3
+ import pytest
4
+ import torchaudio
5
+ import os
6
+ import sys
7
+
8
+ current_dir = os.path.dirname(os.path.abspath(__file__))
9
+ parent_dir = os.path.dirname(current_dir)
10
+ sys.path.append(parent_dir)
11
+ from Client import Client
12
+
13
+
14
+ @pytest.fixture
15
+ def mock_client():
16
+ client = Client("test_sid", "test_client_id", original_sr=44100)
17
+ return client
18
+
19
+ def test_client_init(mock_client):
20
+ assert mock_client.sid == "test_sid"
21
+ assert mock_client.client_id == "test_client_id"
22
+ assert mock_client.call_id == None
23
+ assert mock_client.buffer == bytearray()
24
+ assert mock_client.output_path == "test_sid_output_audio.wav"
25
+ assert mock_client.target_language == None
26
+ assert mock_client.original_sr == 44100
27
+ assert mock_client.vad.sample_rate == 16000
28
+ assert mock_client.vad.frame_length == 25
29
+ assert mock_client.vad.frame_shift == 20
30
+ assert mock_client.vad.energy_threshold == 0.05
31
+ assert mock_client.vad.pre_emphasis == 0.95
32
+
33
+ def test_client_add_bytes(mock_client):
34
+ mock_client.add_bytes(b"test")
35
+ assert mock_client.buffer == b"test"
36
+
37
+ def test_client_resample_and_clear(mock_client):
38
+ location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
39
+ speaking_bytes = wave.open(location + "/speaking.wav", "rb").readframes(-1)
40
+ mock_client.add_bytes(speaking_bytes)
41
+ resampled_waveform = mock_client.resample_and_clear()
42
+ torchaudio.save(location + "testoutput.wav", resampled_waveform, 16000)
43
+ with wave.open(location + "testoutput.wav", "rb") as wf:
44
+ sample_rate = wf.getframerate()
45
+ assert mock_client.buffer == bytearray()
46
+ assert sample_rate == 16000
47
+
48
+ def test_client_vad(mock_client):
49
+ location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
50
+ speaking_bytes = wave.open(location + "/speaking.wav", "rb").readframes(-1)
51
+ mock_client.add_bytes(speaking_bytes)
52
+ resampled_waveform = mock_client.resample_and_clear()
53
+ assert mock_client.buffer == bytearray()
54
+ assert mock_client.vad_analyse(resampled_waveform) == True
55
+ silent_bytes = wave.open(location + "/silence.wav", "rb").readframes(-1)
56
+ mock_client.add_bytes(silent_bytes)
57
+ resampled_waveform = mock_client.resample_and_clear()
58
+ assert mock_client.buffer == bytearray()
59
+ assert mock_client.vad_analyse(resampled_waveform) == False
backend/tests/test_main.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import pytest
3
+ from unittest.mock import AsyncMock, MagicMock, ANY
4
+ import socketio
5
+
6
+ import os
7
+ import sys
8
+
9
+ current_dir = os.path.dirname(os.path.abspath(__file__))
10
+ parent_dir = os.path.dirname(current_dir)
11
+ sys.path.append(parent_dir)
12
+
13
+ from Client import Client
14
+ from main import sio, connect, disconnect, target_language, call_user, answer_call, incoming_audio, clients, rooms
15
+ from unittest.mock import patch
16
+
17
+ sio = socketio.AsyncServer(
18
+ async_mode="asgi",
19
+ cors_allowed_origins="*",
20
+ # engineio_logger=logger,
21
+ )
22
+ # sio.logger.setLevel(logging.DEBUG)
23
+ socketio_app = socketio.ASGIApp(sio)
24
+
25
+ app = FastAPI()
26
+ app.mount("/", socketio_app)
27
+
28
+ @pytest.fixture(autouse=True)
29
+ def setup_clients_and_rooms():
30
+ global clients, rooms
31
+ clients.clear()
32
+ rooms.clear()
33
+ yield
34
+
35
+ @pytest.fixture
36
+ def mock_client():
37
+ client = Client("test_sid", "test_client_id", original_sr=44100)
38
+ return client
39
+
40
+
41
+ @pytest.mark.asyncio
42
+ async def test_connect(mock_client):
43
+ sid = mock_client.sid
44
+ environ = {'QUERY_STRING': 'client_id=test_client_id'}
45
+ await connect(sid, environ)
46
+ assert sid in clients
47
+
48
+ @pytest.mark.asyncio
49
+ async def test_disconnect(mock_client):
50
+ sid = mock_client.sid
51
+ clients[sid] = mock_client
52
+ await disconnect(sid)
53
+ assert sid not in clients
54
+
55
+ @pytest.mark.asyncio
56
+ async def test_target_language(mock_client):
57
+ sid = mock_client.sid
58
+ clients[sid] = mock_client
59
+ target_lang = "fr"
60
+ await target_language(sid, target_lang)
61
+ assert clients[sid].target_language == "fr"
62
+
63
+ # PM - issues with socketio enter_room in these tests
64
+ # @pytest.mark.asyncio
65
+ # async def test_call_user(mock_client):
66
+ # sid = mock_client.sid
67
+ # clients[sid] = mock_client
68
+ # call_id = "1234"
69
+ # await call_user(sid, call_id)
70
+ # assert call_id in rooms
71
+ # assert sid in rooms[call_id]
72
+
73
+ # @pytest.mark.asyncio
74
+ # async def test_answer_call(mock_client):
75
+ # sid = mock_client.sid
76
+ # clients[sid] = mock_client
77
+ # call_id = "1234"
78
+ # await answer_call(sid, call_id)
79
+ # assert call_id in rooms
80
+ # assert sid in rooms[call_id]
81
+
82
+ @pytest.mark.asyncio
83
+ async def test_incoming_audio(mock_client):
84
+ sid = mock_client.sid
85
+ clients[sid] = mock_client
86
+ data = b"\x01"
87
+ call_id = "1234"
88
+ await incoming_audio(sid, data, call_id)
89
+ assert clients[sid].get_length() != 0
90
+
backend/utils/__pycache__/text_rank.cpython-310.pyc ADDED
Binary file (2.03 kB). View file
 
backend/utils/text_rank.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ import pytextrank
3
+ from spacy.tokens import Span
4
+
5
+ # Define decorator for converting to singular version of words
6
+ @spacy.registry.misc("plural_scrubber")
7
+ def plural_scrubber():
8
+ def scrubber_func(span: Span) -> str:
9
+ return span.lemma_
10
+ return scrubber_func
11
+
12
+
13
+ def model_selector(target_language: str):
14
+
15
+ # Load subset of non-english models
16
+ language_model = {
17
+ "spa": "es_core_news_sm",
18
+ "fra": "fr_core_news_sm",
19
+ "pol": "pl_core_news_sm",
20
+ "deu": "de_core_news_sm",
21
+ "ita": "it_core_news_sm",
22
+ "por": "pt_core_news_sm",
23
+ "nld": "nl_core_news_sm",
24
+ "fin": "fi_core_news_sm",
25
+ "ron": "ro_core_news_sm",
26
+ "rus": "ru_core_news_sm"
27
+ }
28
+
29
+ try:
30
+ nlp = spacy.load(language_model[target_language])
31
+
32
+ except KeyError:
33
+ # Load a spaCy English model
34
+ nlp = spacy.load("en_core_web_lg")
35
+
36
+ # Add TextRank component to pipeline with stopwords
37
+ nlp.add_pipe("textrank", config={
38
+ "stopwords": {token:["NOUN"] for token in nlp.Defaults.stop_words},
39
+ "scrubber": {"@misc": "plural_scrubber"}})
40
+
41
+ return nlp
42
+
43
+
44
+ def extract_terms(text, target_language, length):
45
+ nlp = model_selector(target_language)
46
+
47
+ # Perform fact extraction on overall summary and segment summaries
48
+ doc = nlp(text)
49
+
50
+ if length < 100:
51
+ # Get single most used key term
52
+ phrases = {phrase.text for phrase in doc._.phrases[:1]}
53
+ elif length > 100 and length < 300:
54
+ # Create unique set from top 2 ranked phrases
55
+ phrases = {phrase.text for phrase in doc._.phrases[:2]}
56
+ if length > 300:
57
+ # Create unique set from top 3 ranked phrases
58
+ phrases = {phrase.text for phrase in doc._.phrases[:3]}
59
+
60
+ return list(phrases)