Fuzzy search fix
- backend/.DS_Store +0 -0
- backend/.env +2 -0
- backend/.gitignore +2 -0
- backend/Client.py +81 -0
- backend/__pycache__/Client.cpython-310.pyc +0 -0
- backend/__pycache__/main.cpython-310.pyc +0 -0
- backend/logging.yaml +22 -0
- backend/main.py +343 -0
- backend/models/Seamless/vad_s2st_sc_24khz_main.yaml +25 -0
- backend/models/SeamlessStreaming/vad_s2st_sc_main.yaml +21 -0
- backend/mongodb/endpoints/__pycache__/calls.cpython-310.pyc +0 -0
- backend/mongodb/endpoints/__pycache__/users.cpython-310.pyc +0 -0
- backend/mongodb/endpoints/calls.py +96 -0
- backend/mongodb/endpoints/users.py +53 -0
- backend/mongodb/models/__pycache__/calls.cpython-310.pyc +0 -0
- backend/mongodb/models/__pycache__/users.cpython-310.pyc +0 -0
- backend/mongodb/models/calls.py +75 -0
- backend/mongodb/models/users.py +44 -0
- backend/mongodb/operations/__pycache__/calls.cpython-310.pyc +0 -0
- backend/mongodb/operations/__pycache__/users.cpython-310.pyc +0 -0
- backend/mongodb/operations/calls.py +285 -0
- backend/mongodb/operations/users.py +77 -0
- backend/pcmToWav.py +34 -0
- backend/preprocess_wav.py +65 -0
- backend/requirements.txt +28 -0
- backend/routes/__init__.py +1 -0
- backend/routes/__pycache__/__init__.cpython-310.pyc +0 -0
- backend/routes/__pycache__/routing.cpython-310.pyc +0 -0
- backend/routes/routing.py +9 -0
- backend/seamless/__init__.py +0 -0
- backend/seamless/room.py +64 -0
- backend/seamless/simuleval_agent_directory.py +171 -0
- backend/seamless/simuleval_transcoder.py +428 -0
- backend/seamless/speech_and_text_output.py +15 -0
- backend/seamless/transcoder_helpers.py +43 -0
- backend/seamless_utils.py +210 -0
- backend/tests/__pycache__/test_client.cpython-310-pytest-8.1.1.pyc +0 -0
- backend/tests/__pycache__/test_main.cpython-310-pytest-8.1.1.pyc +0 -0
- backend/tests/__pycache__/test_main.cpython-310.pyc +0 -0
- backend/tests/silence.wav +0 -0
- backend/tests/speaking.wav +0 -0
- backend/tests/test_client.py +59 -0
- backend/tests/test_main.py +90 -0
- backend/utils/__pycache__/text_rank.cpython-310.pyc +0 -0
- backend/utils/text_rank.py +60 -0
backend/.DS_Store
ADDED
Binary file (6.15 kB)
backend/.env
ADDED
@@ -0,0 +1,2 @@
MONGODB_URI=mongodb+srv://benjolo:26qtppddzz2jx9@it-cluster1.4cwyb2f.mongodb.net/?retryWrites=true&w=majority&appName=IT-Cluster1
OPENAI_API_KEY=sk-proj-vc4w7s6gkfwFG8xLBunZT3BlbkFJ8h9zOoyS0OY756vMgBcc
backend/.gitignore
ADDED
@@ -0,0 +1,2 @@
myenv
.pytest_cache
backend/Client.py
ADDED
@@ -0,0 +1,81 @@
from typing import Generator, Tuple
import wave
import os

import torchaudio
from vad import EnergyVAD

TARGET_SAMPLING_RATE = 16000

def create_frames(data: bytes, frame_duration: int) -> Tuple[Generator[bytes, None, None], int]:
    frame_size = int(TARGET_SAMPLING_RATE * (frame_duration / 1000))
    return (data[i:i + frame_size] for i in range(0, len(data), frame_size)), frame_size

def detect_activity(energies: list):
    # require at least 1/12 of all frames to be active overall...
    if sum(energies) < len(energies) / 12:
        return False
    count = 0
    for energy in energies:
        if energy == 1:
            count += 1
            if count == 12:  # ...and at least 12 consecutive active frames
                return True
        else:
            count = 0
    return False

class Client:
    def __init__(self, sid, client_id, username, call_id=None, original_sr=None):
        self.sid = sid
        self.client_id = client_id
        self.username = username
        self.call_id = call_id
        self.buffer = bytearray()
        self.output_path = self.sid + "_output_audio.wav"
        self.target_language = None
        self.original_sr = original_sr
        self.vad = EnergyVAD(
            sample_rate=TARGET_SAMPLING_RATE,
            frame_length=25,
            frame_shift=20,
            energy_threshold=0.05,
            pre_emphasis=0.95,
        )  # PM - Default values given in the docs for this class

    def add_bytes(self, new_bytes):
        self.buffer += new_bytes

    def resample_and_clear(self):
        print(f"📥 [ClientAudioBuffer] Writing {len(self.buffer)} bytes to {self.output_path}")
        with wave.open(self.sid + "_OG.wav", "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.original_sr)
            wf.setnframes(0)
            wf.setcomptype("NONE", "not compressed")
            wf.writeframes(self.buffer)
        waveform, sample_rate = torchaudio.load(self.sid + "_OG.wav")
        resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLING_RATE, dtype=waveform.dtype)
        resampled_waveform = resampler(waveform)
        self.buffer = bytearray()
        return resampled_waveform

    def vad_analyse(self, resampled_waveform):
        torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)
        vad_array = self.vad(resampled_waveform)
        print(f"VAD OUTPUT: {vad_array}")
        return detect_activity(vad_array)

    def write_to_file(self, resampled_waveform):
        torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)

    def get_length(self):
        return len(self.buffer)

    def __del__(self):
        if len(self.buffer) > 0:
            print(f"🚨 [ClientAudioBuffer] Buffer not empty for {self.sid} ({len(self.buffer)} bytes)!")
        if os.path.exists(self.output_path):
            os.remove(self.output_path)
        if os.path.exists(self.sid + "_OG.wav"):
            os.remove(self.sid + "_OG.wav")
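For reference, a minimal sketch of the buffering flow this class implements (the sid, client_id, and sample values here are hypothetical; `EnergyVAD` comes from the `vad` package imported above):

    client = Client(sid="abc123", client_id="user-1", username="demo", original_sr=48000)
    client.add_bytes(b"\x00\x00" * 48000)      # one second of silent 16-bit PCM
    waveform = client.resample_and_clear()     # writes <sid>_OG.wav, resamples to 16 kHz, clears the buffer
    if client.vad_analyse(waveform):           # True only when speech is detected
        client.write_to_file(waveform)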
backend/__pycache__/Client.cpython-310.pyc
ADDED
Binary file (3.41 kB)
backend/__pycache__/main.cpython-310.pyc
ADDED
Binary file (6.48 kB)
backend/logging.yaml
ADDED
@@ -0,0 +1,22 @@
version: 1
disable_existing_loggers: false

formatters:
  standard:
    format: "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s"

handlers:
  console:
    class: logging.StreamHandler
    formatter: standard
    stream: ext://sys.stdout

loggers:
  uvicorn:
    error:
      propagate: true

root:
  level: INFO
  handlers: [console]
  propagate: no
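A minimal sketch of how a config like this is typically loaded (assumes PyYAML is installed; the diff does not show where the file is consumed):

    import logging.config

    import yaml

    # Parse the YAML into a dict and hand it to the stdlib dictConfig loader
    with open("backend/logging.yaml") as f:
        logging.config.dictConfig(yaml.safe_load(f))

    logging.getLogger(__name__).info("logging configured")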
backend/main.py
ADDED
@@ -0,0 +1,343 @@
from operator import itemgetter
import os
from datetime import datetime
import uvicorn
from typing import Any, Optional, Tuple, Dict, TypedDict
from urllib import parse
from uuid import uuid4
import logging
from fastapi.logger import logger as fastapi_logger
import sys

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter, Body, Request, status
from pymongo import MongoClient
from dotenv import dotenv_values
from routes import router as api_router
from contextlib import asynccontextmanager
import requests

from typing import List
from datetime import date
from mongodb.operations.calls import *
from mongodb.operations.users import *
from mongodb.models.calls import UserCall, UpdateCall
# from mongodb.endpoints.calls import *

from transformers import AutoProcessor, SeamlessM4Tv2Model

# from seamless_communication.inference import Translator
from Client import Client
import numpy as np
import torch
import socketio

# Configure logger
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")

gunicorn_error_logger.propagate = True
gunicorn_logger.propagate = True
uvicorn_access_logger.propagate = True

uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers

# sio is the main socket.io entrypoint
sio = socketio.AsyncServer(
    async_mode="asgi",
    cors_allowed_origins="*",
    logger=gunicorn_logger,
    engineio_logger=gunicorn_logger,
)
# sio.logger.setLevel(logging.DEBUG)
socketio_app = socketio.ASGIApp(sio)
# app.mount("/", socketio_app)

config = dotenv_values(".env")

# Read connection string from environment vars
# uri = os.environ['MONGODB_URI']

# Read connection string from .env file
uri = config['MONGODB_URI']


# MongoDB Connection Lifespan Events
@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup logic
    app.mongodb_client = MongoClient(uri)
    app.database = app.mongodb_client['IT-Cluster1']  # connect to InterpreTalk primary db
    try:
        app.mongodb_client.admin.command('ping')
        print("MongoDB Connection Established...")
    except Exception as e:
        print(e)

    yield

    # shutdown logic
    print("Closing MongoDB Connection...")
    app.mongodb_client.close()

app = FastAPI(lifespan=lifespan, logger=gunicorn_logger)

# New CORS functionality
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # configured node app port
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(api_router)  # include routers for user, calls and transcripts operations

DEBUG = True

ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"

TARGET_SAMPLING_RATE = 16000
MAX_BYTES_BUFFER = 960_000

print("")
print("")
print("=" * 18 + " InterpreTalk is starting... " + "=" * 18)

###############################################
# Configure socketio server
###############################################

# TODO PM - change this to the actual path
# seamless remnant code
CLIENT_BUILD_PATH = "../streaming-react-app/dist/"
static_files = {
    "/": CLIENT_BUILD_PATH,
    "/assets/seamless-db6a2555.svg": {
        "filename": CLIENT_BUILD_PATH + "assets/seamless-db6a2555.svg",
        "content_type": "image/svg+xml",
    },
}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")

# PM - hardcoding temporarily as my GPU doesn't have enough VRAM
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)


bytes_data = bytearray()
model_name = "seamlessM4T_v2_large"
vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"

clients = {}
rooms = {}


def get_collection_users():
    return app.database["user_records"]

def get_collection_calls():
    return app.database["call_records"]


@app.get("/home/", response_description="Welcome User")
def test():
    return {"message": "Welcome to InterpreTalk!"}


async def send_translated_text(client_id, username, original_text, translated_text, room_id):
    # print(rooms)  # Debugging
    # print(clients)  # Debugging

    data = {
        "author_id": str(client_id),
        "author_username": str(username),
        "original_text": str(original_text),
        "translated_text": str(translated_text),
        "timestamp": str(datetime.now())
    }
    gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
    await sio.emit("translated_text", data, room=room_id)
    gunicorn_logger.info("SUCCESSFULLY SENT TRANSLATED TEXT TO FRONTEND")


@sio.on("connect")
async def connect(sid, environ):
    print(f"📥 [event: connected] sid={sid}")
    query_params = dict(parse.parse_qsl(environ["QUERY_STRING"]))

    client_id = query_params.get("client_id")
    gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")

    # get username for Client object from DB
    username = find_name_from_id(get_collection_users(), client_id)

    # sid = socket id; client_id = client-specific ID, always the same for the same user
    clients[sid] = Client(sid, client_id, username)
    print(clients[sid].username)
    gunicorn_logger.warning(f"Client connected: {sid}")
    gunicorn_logger.warning(clients)


@sio.on("disconnect")
async def disconnect(sid):
    gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")

    call_id = clients[sid].call_id
    user_id = clients[sid].client_id
    target_language = clients[sid].target_language

    clients.pop(sid, None)

    # Perform key term extraction and summarisation
    try:
        # Get combined caption field for call record based on call_id
        key_terms = term_extraction(get_collection_calls(), call_id, user_id, target_language)

        # Perform summarisation based on target language
        summary_result = summarise(get_collection_calls(), call_id, user_id, target_language)

    except:
        gunicorn_logger.error(f"📤 [event: term_extraction/summarisation request error] sid={sid}, call={call_id}")


@sio.on("target_language")
async def target_language(sid, target_lang):
    gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
    clients[sid].target_language = target_lang


@sio.on("call_user")
async def call_user(sid, call_id):
    clients[sid].call_id = call_id
    gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")
    rooms[call_id] = rooms.get(call_id, [])
    if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
        rooms[call_id].append(sid)
        sio.enter_room(sid, call_id)
    else:
        gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)

    # BO - Get call id from dictionary created during socketio connection
    client_id = clients[sid].client_id

    gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
    # BO -> Create call record with caller_id and call_id fields (None for callee, duration, terms...)
    request_data = {
        "call_id": str(call_id),
        "caller_id": str(client_id),
        "creation_date": str(datetime.now())
    }

    response = create_calls(get_collection_calls(), request_data)
    print(response)  # BO - print created db call record


@sio.on("audio_config")
async def audio_config(sid, sample_rate):
    clients[sid].original_sr = sample_rate


@sio.on("answer_call")
async def answer_call(sid, call_id):

    clients[sid].call_id = call_id
    gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")
    rooms[call_id] = rooms.get(call_id, [])
    if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
        rooms[call_id].append(sid)
        sio.enter_room(sid, call_id)
    else:
        gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)

    # BO - Get call id from dictionary created during socketio connection
    client_id = clients[sid].client_id

    # BO -> Update call record with callee_id field based on call_id
    gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
    request_data = {
        "callee_id": client_id
    }

    response = update_calls(get_collection_calls(), call_id, request_data)
    print(response)  # BO - print updated db call record


@sio.on("incoming_audio")
async def incoming_audio(sid, data, call_id):
    try:
        clients[sid].add_bytes(data)

        if clients[sid].get_length() >= MAX_BYTES_BUFFER:
            gunicorn_logger.info('Buffer full, now outputting...')
            output_path = clients[sid].output_path
            resampled_audio = clients[sid].resample_and_clear()
            vad_result = clients[sid].vad_analyse(resampled_audio)
            # source lang is the speaker's target language 😃
            src_lang = clients[sid].target_language

            if vad_result:
                gunicorn_logger.info('Speech detected, now processing audio.....')
                tgt_sid = next(id for id in rooms[call_id] if id != sid)
                tgt_lang = clients[tgt_sid].target_language
                # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
                output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt", sampling_rate=TARGET_SAMPLING_RATE).to(device)
                model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
                asr_text = processor.decode(model_output, skip_special_tokens=True)
                print(f"ASR TEXT = {asr_text}")
                # ASR TEXT => ORIGINAL TEXT

                if src_lang != tgt_lang:
                    t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt").to(device)
                    translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
                    translated_text = processor.decode(translated_data, skip_special_tokens=True)
                    print(f"TRANSLATED TEXT = {translated_text}")
                else:
                    # PM - both users have the same language selected, no need to translate
                    translated_text = asr_text

                # PM - text_output is a list with 1 string
                await send_translated_text(clients[sid].client_id, clients[sid].username, asr_text, translated_text, call_id)

                # BO -> send translated_text to mongodb as caption record update based on call_id
                await send_captions(clients[sid].client_id, clients[sid].username, asr_text, translated_text, call_id)

    except Exception as e:
        gunicorn_logger.exception(f"Error in incoming_audio: {e}")


async def send_captions(client_id, username, original_text, translated_text, call_id):
    # BO -> Update call record's caption field based on call_id
    print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")

    data = {
        "author_id": str(client_id),
        "author_username": str(username),
        "original_text": str(original_text),
        "translated_text": str(translated_text),
        "timestamp": str(datetime.now())
    }

    response = update_captions(get_collection_calls(), get_collection_users(), call_id, data)
    return response


app.mount("/", socketio_app)


if __name__ == '__main__':
    uvicorn.run("main:app", host='0.0.0.0', port=7860, log_level="info")


# Running in Docker Container
if __name__ != "__main__":
    fastapi_logger.setLevel(gunicorn_logger.level)
else:
    fastapi_logger.setLevel(logging.DEBUG)
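For orientation, a minimal sketch of the socket flow the handlers above expect, using the python-socketio client. This is illustrative only: the URL, IDs, and payload are hypothetical, a real caller streams raw PCM chunks rather than a whole file, and a second participant must join the room via "answer_call" before translations flow back:

    import socketio

    sio_client = socketio.Client()
    # client_id is read from the connection query string by the connect handler
    sio_client.connect("http://localhost:7860?client_id=65ede65b6d246e52aaba9d4f")

    sio_client.emit("target_language", "spa")   # this speaker's language
    sio_client.emit("audio_config", 48000)      # original mic sample rate
    sio_client.emit("call_user", "65eef930e9abd3b1e3506906")

    with open("speaking.wav", "rb") as f:       # tuple -> (data, call_id) arguments
        sio_client.emit("incoming_audio", (f.read(), "65eef930e9abd3b1e3506906"))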
backend/models/Seamless/vad_s2st_sc_24khz_main.yaml
ADDED
@@ -0,0 +1,25 @@
agent_class: seamless_communication.streaming.agents.seamless_s2st.SeamlessS2STDualVocoderVADAgent
monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
unity_model_name: seamless_streaming_unity
sentencepiece_model: spm_256k_nllb100.model

task: s2st
tgt_lang: "eng"
min_unit_chunk_size: 50
decision_threshold: 0.7
no_early_stop: True
block_ngrams: True
vocoder_name: vocoder_v2
expr_vocoder_name: vocoder_pretssel
gated_model_dir: .
expr_vocoder_gain: 3.0
wav2vec_yaml: wav2vec.yaml
min_starting_wait_w2vbert: 192

config_yaml: cfg_fbank_u2t.yaml
upstream_idx: 1
detokenize_only: True
device: cuda:0
max_len_a: 0
max_len_b: 1000
backend/models/SeamlessStreaming/vad_s2st_sc_main.yaml
ADDED
@@ -0,0 +1,21 @@
agent_class: seamless_communication.streaming.agents.seamless_streaming_s2st.SeamlessStreamingS2STJointVADAgent
monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
unity_model_name: seamless_streaming_unity
sentencepiece_model: spm_256k_nllb100.model

task: s2st
tgt_lang: "eng"
min_unit_chunk_size: 50
decision_threshold: 0.7
no_early_stop: True
block_ngrams: True
vocoder_name: vocoder_v2
wav2vec_yaml: wav2vec.yaml
min_starting_wait_w2vbert: 192

config_yaml: cfg_fbank_u2t.yaml
upstream_idx: 1
detokenize_only: True
device: cuda:0
max_len_a: 0
max_len_b: 1000
backend/mongodb/endpoints/__pycache__/calls.cpython-310.pyc
ADDED
Binary file (4.74 kB)
backend/mongodb/endpoints/__pycache__/users.cpython-310.pyc
ADDED
Binary file (2.43 kB)
backend/mongodb/endpoints/calls.py
ADDED
@@ -0,0 +1,96 @@
from fastapi import APIRouter, Body, Request, status, HTTPException
from typing import List
from datetime import date

import sys

from ..operations import calls as calls
from ..models.calls import UserCaptions, UserCall, UpdateCall
from ..endpoints.users import get_collection_users

router = APIRouter(prefix="/call",
                   tags=["Calls"])

def get_collection_calls(request: Request):
    try:
        return request.app.database["call_records"]
        # return request.app.database["call_test"]
    except:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unable to find call records database.")

@router.post("/create-call", response_description="Create a new user call record", status_code=status.HTTP_201_CREATED, response_model=UserCall)
async def create_calls(request: Request, user_calls: UserCall = Body(...)):
    collection = get_collection_calls(request)
    return calls.create_calls(collection, user_calls)

@router.get("/list-call", response_description="List all existing call records", response_model=List[UserCall])
async def list_calls(request: Request, limit: int):
    collection = get_collection_calls(request)
    return calls.list_calls(collection, limit)

@router.get("/find-call/{call_id}", response_description="Find a call based on its call ID", response_model=UserCall)
async def find_call(request: Request, call_id: str):
    collection = get_collection_calls(request)
    return calls.find_call(collection, call_id)

@router.get("/find-user-calls/{user_id}", response_description="Find a user's calls based on their user ID", response_model=List[UserCall])
async def find_user_calls(request: Request, user_id: str):
    collection = get_collection_calls(request)
    return calls.find_user_calls(collection, user_id)

@router.get("/get-captions/{user_id}", response_description="Get combined caption text for a call and user")
async def get_caption_text(request: Request, call_id: str, user_id: str):
    collection = get_collection_calls(request)
    return calls.get_caption_text(collection, call_id, user_id)

'''Key terms list can have variable length -> using POST request over GET'''
@router.post("/find-term/", response_description="Find calls based on key term list", response_model=List[UserCall])
async def list_transcripts_by_key_terms(request: Request, key_terms: List[str]):
    collection = get_collection_calls(request)
    return calls.list_transcripts_by_key_terms(collection, key_terms)

@router.get("/find-date/{start_date}/{end_date}", response_description="Find calls based on date ranges", response_model=List[UserCall])
async def list_transcripts_by_dates(request: Request, start_date: str, end_date: str):
    collection = get_collection_calls(request)
    return calls.list_transcripts_by_dates(collection, start_date, end_date)

@router.get("/find-duration/{min_len}/{max_len}", response_description="Find calls based on call duration in milliseconds", response_model=List[UserCall])
async def list_transcripts_by_duration(request: Request, min_len: int, max_len: int):
    collection = get_collection_calls(request)
    return calls.list_transcripts_by_duration(collection, min_len, max_len)

@router.put("/update-call/{call_id}", response_description="Update an existing call", response_model=UpdateCall)
async def update_calls(request: Request, call_id: str, user_calls: UpdateCall = Body(...)):
    collection = get_collection_calls(request)
    return calls.update_calls(collection, call_id, user_calls)

@router.put("/update-captions/{call_id}", response_description="Append captions to an existing call", response_model=UpdateCall)
async def update_captions(request: Request, call_id: str, user_calls: UserCaptions = Body(...)):
    call_collection = get_collection_calls(request)
    user_collection = get_collection_users(request)
    return calls.update_captions(call_collection, user_collection, call_id, user_calls)

@router.delete("/delete-call/{call_id}", response_description="Delete a call by its id")
async def delete_call(request: Request, call_id: str):
    collection = get_collection_calls(request)
    return calls.delete_calls(collection, call_id)

@router.get("/full-text-search/{query}", response_description="Perform full text search on caption fields", response_model=List[UserCall])
async def full_text_search(request: Request, query: str):
    collection = get_collection_calls(request)
    return calls.full_text_search(collection, query)

@router.get("/fuzzy-search/{user_id}/{query}", response_description="Perform fuzzy text search on caption fields", response_model=List[UserCall])
async def fuzzy_search(request: Request, user_id: str, query: str):
    collection = get_collection_calls(request)
    return calls.fuzzy_search(collection, user_id, query)

@router.get("/summarise/{call_id}/{user_id}/{target_language}", response_description="Perform gpt-3.5 summarisation on a call record")
async def summarise(request: Request, call_id: str, user_id: str, target_language: str):
    collection = get_collection_calls(request)
    return calls.summarise(collection, call_id, user_id, target_language)

@router.get("/term-extraction/{call_id}/{user_id}/{target_language}", response_description="Perform key term extraction on a call record")
async def term_extraction(request: Request, call_id: str, user_id: str, target_language: str):
    collection = get_collection_calls(request)
    return calls.term_extraction(collection, call_id, user_id, target_language)
backend/mongodb/endpoints/users.py
ADDED
@@ -0,0 +1,53 @@
from fastapi import APIRouter, Body, Request, status, HTTPException
from typing import List
import sys
from ..models.users import User, UpdateUser
from ..operations import users as users

router = APIRouter(prefix="/user",
                   tags=["User"])

def get_collection_users(request: Request):
    db = request.app.database["user_records"]
    return db

@router.post("/", response_description="Create a new user", status_code=status.HTTP_201_CREATED, response_model=User)
async def create_user(request: Request, user: User = Body(...)):
    collection = get_collection_users(request)
    return users.create_user(collection, user)

@router.get("/", response_description="List users", response_model=List[User])
async def list_users(request: Request):
    collection = get_collection_users(request)
    return users.list_users(collection, 100)

@router.put("/{user_id}", response_description="Update a user", response_model=UpdateUser)
async def update_user(request: Request, user_id: str, user: UpdateUser = Body(...)):
    collection = get_collection_users(request)
    return users.update_user(collection, user_id, user)

@router.get("/{user_id}", response_description="Get a single user by id", response_model=User)
async def find_user(request: Request, user_id: str):
    collection = get_collection_users(request)
    return users.find_user(collection, user_id)

@router.get("/find-name-id/{user_id}", response_description="Get a username from a user id")
async def find_name_from_id(request: Request, user_id: str):
    collection = get_collection_users(request)
    return users.find_name_from_id(collection, user_id)

@router.get("/name/{user_name}", response_description="Get a single user by name", response_model=User)
async def find_user_name(request: Request, user_name: str):
    collection = get_collection_users(request)
    return users.find_user_name(collection, user_name)

@router.get("/email/{email_addr}", response_description="Get a single user by email", response_model=User)
async def find_user_email(request: Request, email_addr: str):
    collection = get_collection_users(request)
    return users.find_user_email(collection, email_addr)

@router.delete("/{user_id}", response_description="Delete a user")
async def delete_user(request: Request, user_id: str):
    collection = get_collection_users(request)
    return users.delete_user(collection, user_id)
backend/mongodb/models/__pycache__/calls.cpython-310.pyc
ADDED
Binary file (3.09 kB)
backend/mongodb/models/__pycache__/users.cpython-310.pyc
ADDED
Binary file (1.73 kB)
backend/mongodb/models/calls.py
ADDED
@@ -0,0 +1,75 @@
import uuid
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel, Field, PrivateAttr
import sys


''' Class for storing captions generated by SeamlessM4T'''
class UserCaptions(BaseModel):
    _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)  # private attr not included in http calls
    author_id: Optional[str] = None
    author_username: Optional[str] = None
    original_text: str
    translated_text: str
    timestamp: datetime = Field(default_factory=datetime.now)

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "author_id": "gLZrfTwXyLUPB3eT7xT2HZnZiZT2",
                "author_username": "shamzino",
                "original_text": "eng: This is original_text english text",
                "translated_text": "spa: este es el texto traducido al español",
                "timestamp": "2024-03-28T16:15:50.956055",
            }
        }


'''Class for storing past call records from users'''
class UserCall(BaseModel):
    _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
    call_id: Optional[str] = None
    caller_id: Optional[str] = None
    callee_id: Optional[str] = None
    creation_date: datetime = Field(default_factory=datetime.now, alias="date")
    duration: Optional[int] = None  # milliseconds
    captions: Optional[List[UserCaptions]] = None
    key_terms: Optional[dict] = None
    summaries: Optional[dict] = None


    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "call_id": "65eef930e9abd3b1e3506906",
                "caller_id": "65ede65b6d246e52aaba9d4f",
                "callee_id": "65edda944340ac84c1f00758",
                "duration": 360,
                "captions": [{"author_id": "gLZrfTwXyLUPB3eT7xT2HZnZiZT2", "author_username": "shamzino", "original_text": "eng: This is original_text english text", "translated_text": "spa: este es el texto traducido al español", "timestamp": "2024-03-28T16:15:50.956055"},
                             {"author_id": "g7pR1qCibzQf5mDP9dGtcoWeEc92", "author_username": "benjino", "original_text": "eng: This is source english text", "translated_text": "spa: este es el texto fuente al español", "timestamp": "2024-03-28T16:16:20.34625"}],
                "key_terms": {"gLZrfTwXyLUPB3eT7xT2HZnZiZT2": ["original_text", "source", "english", "text"], "g7pR1qCibzQf5mDP9dGtcoWeEc92": ["translated_text", "destination", "spanish", "text"]},
                "summaries": {"gLZrfTwXyLUPB3eT7xT2HZnZiZT2": "This is a short test on language translation", "65edda944340ac84c1f00758": "Esta es una breve prueba sobre traducción de idiomas."}
            }
        }


''' Class for updating User Call record'''
class UpdateCall(BaseModel):
    call_id: Optional[str] = None
    caller_id: Optional[str] = None
    callee_id: Optional[str] = None
    duration: Optional[int] = None
    captions: Optional[List[UserCaptions]] = None
    key_terms: Optional[List[str]] = None

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "duration": "500"
            }
        }
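As a quick sanity check, the models above can be exercised directly with their example payloads; a minimal sketch (assumes pydantic v2, which the `populate_by_name`/`json_schema_extra` config keys imply):

    from mongodb.models.calls import UserCall

    # "date" is the alias of creation_date, so it populates that field;
    # populate_by_name=True means the field name itself would work too.
    record = UserCall.model_validate({
        "call_id": "65eef930e9abd3b1e3506906",
        "caller_id": "65ede65b6d246e52aaba9d4f",
        "date": "2024-03-28T16:15:50",
        "captions": [{"original_text": "hello", "translated_text": "hola"}],
    })
    print(record.creation_date, record.captions[0].translated_text)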
backend/mongodb/models/users.py
ADDED
@@ -0,0 +1,44 @@
import uuid
from typing import List, Optional
from pydantic import BaseModel, Field, SecretStr, PrivateAttr
from pydantic.networks import EmailStr


'''Class for user model used to relate users to past calls'''
class User(BaseModel):
    _id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)  # private attr not included in http calls
    user_id: str
    name: str
    email: EmailStr = Field(unique=True, index=True)
    # password: SecretStr
    call_ids: Optional[List[str]] = None

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "user_id": "65ede65b6d246e52aaba9d4f",
                "name": "benjolo",
                "email": "benjolounchained@gmail.com",
                "call_ids": ["65e205ced1be3a22854ff300", "65df8c3eba9c7c2ed1b20e85"]
            }
        }

'''Class for updating user records'''
class UpdateUser(BaseModel):
    user_id: Optional[str] = None
    name: Optional[str] = None
    email: Optional[EmailStr] = None
    ''' To decode use -> SecretStr("abc").get_secret_value()'''
    # password: Optional[SecretStr]
    call_ids: Optional[List[str]] = None

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "email": "benjolounchained21@gmail.com",
                "call_ids": ["65e205ced1be3a22854ff300", "65df8c3eba9c7c2ed1b20e85", "65eef930e9abd3b1e3506906"]
            }
        }
backend/mongodb/operations/__pycache__/calls.cpython-310.pyc
ADDED
Binary file (6.61 kB)
backend/mongodb/operations/__pycache__/users.cpython-310.pyc
ADDED
Binary file (2.93 kB)
backend/mongodb/operations/calls.py
ADDED
@@ -0,0 +1,285 @@
from fastapi import Body, Request, HTTPException, status
from fastapi.encoders import jsonable_encoder
import sys
from ..models.calls import UpdateCall, UserCall, UserCaptions
from ..operations.users import *
from utils.text_rank import extract_terms
from openai import OpenAI

from time import sleep
import os
from dotenv import dotenv_values

# Used within calls to create call record in main.py
def create_calls(collection, user: UserCall = Body(...)):
    calls = jsonable_encoder(user)
    new_calls = collection.insert_one(calls)
    created_calls = collection.find_one({"_id": new_calls.inserted_id})

    return created_calls

def list_calls(collection, limit: int):
    try:
        calls = collection.find(limit=limit)
        return list(calls)
    except:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No existing call records yet.")


'''Finding calls based on call id'''
def find_call(collection, call_id: str):
    user_calls = collection.find_one({"call_id": call_id})
    if user_calls is not None:
        return user_calls
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with ID: '{call_id}' not found.")


'''Finding calls based on user id'''
def find_user_calls(collection, user_id: str):
    user_calls = list(collection.find({"$or": [{"caller_id": user_id}, {"callee_id": user_id}]}))  # match on caller or callee ID
    if len(user_calls):
        return user_calls
    else:
        return []  # return empty list if no existing calls for TranscriptView frontend component


'''Finding calls based on key terms list'''
def list_transcripts_by_key_terms(collection, key_terms_list: list[str] = Body(...)):
    key_terms_list = jsonable_encoder(key_terms_list)

    call_records = list(collection.find({"key_terms": {"$in": key_terms_list}}, {'_id': 0}))  # exclude returning ObjectID in find()

    # Check if any call records were returned
    if len(call_records):
        return call_records
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with key terms: '{key_terms_list}' not found!")


'''Finding calls based on date ranges'''
def list_transcripts_by_dates(collection, start_date: str, end_date: str):
    # Convert strings to date string in YYYY-MM-ddT00:00:00 format
    start_date = f'{start_date}T00:00:00'
    end_date = f'{end_date}T00:00:00'

    call_records = list(collection.find({"date": {"$gte": start_date, "$lte": end_date}}))

    if len(call_records):
        return call_records
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with creation date between: '{start_date} - {end_date}' not found!")


'''Finding calls based on call lengths'''
def list_transcripts_by_duration(collection, min_len: int, max_len: int):
    call_records = list(collection.find({"duration": {"$gte": min_len, "$lte": max_len}}))

    if len(call_records):
        return call_records
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Call with duration between: '{min_len} - {max_len}' milliseconds not found!")


def update_calls(collection, call_id: str, calls: UpdateCall = Body(...)):
    # accept both pydantic models (from the HTTP endpoint) and plain dicts (from main.py)
    if not isinstance(calls, dict):
        calls = calls.model_dump()
    calls = {k: v for k, v in calls.items() if v is not None}  # drop unset fields
    print(calls)

    if len(calls) >= 1:
        update_result = collection.update_one({"call_id": call_id}, {"$set": calls})

        if update_result.modified_count == 0:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not updated!")

        if (existing_item := collection.find_one({"call_id": call_id})) is not None:
            return existing_item

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not found!")


def update_captions(call_collection, user_collection, call_id: str, captions: UserCaptions = Body(...)):
    # accept both pydantic models (from the HTTP endpoint) and plain dicts (from main.py)
    if not isinstance(captions, dict):
        captions = captions.model_dump()
    captions = {k: v for k, v in captions.items() if v is not None}

    # index user_id from caption object
    userID = captions["author_id"]

    # use user id to get user name
    username = find_name_from_id(user_collection, userID)

    # add user name to captions json/object
    captions["author_username"] = username

    if len(captions) >= 1:
        update_result = call_collection.update_one({"call_id": call_id},
                                                   {"$push": {"captions": captions}})

        if update_result.modified_count == 0:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Captions not updated!")

        if (existing_item := call_collection.find_one({"call_id": call_id})) is not None:
            return existing_item

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Captions not found!")


def delete_calls(collection, call_id: str):
    deleted_calls = collection.delete_one({"call_id": call_id})

    if deleted_calls.deleted_count == 1:
        return "Call deleted successfully!"

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not found!")


def get_caption_text(collection, call_id, user_id):
    call_record = find_call(collection, call_id)

    try:  # Check if call has any captions first
        caption_records = call_record['captions']
    except KeyError:
        return None

    # iterate through the embedded caption documents: keep this user's original text
    # and the other party's translated text
    combined_text = []

    for caption_segment in caption_records:
        if caption_segment['author_id'] == user_id:
            combined_text.append(caption_segment['original_text'])
        else:
            combined_text.append(caption_segment['translated_text'])

    return " ".join(combined_text)


# standard exact match based full text search
def full_text_search(collection, query):
    # drop any existing indexes and create new one
    collection.drop_indexes()
    collection.create_index([('captions.original_text', 'text'), ('captions.translated_text', 'text')],
                            name='captions')

    results = list(collection.find({"$text": {"$search": query}}))
    return results

# approximate string matching
def fuzzy_search(collection, user_id, query):
    user_calls = collection.find({"$or": [{"caller_id": user_id}, {"callee_id": user_id}]})
    print("USER CALLS:", user_calls)

    print("COLLECTION:", collection)

    # drop any existing indexes and create new one
    collection.drop_indexes()
    collection.create_index([('captions.original_text', 'text'), ('captions.translated_text', 'text')],
                            name='captions')


    pipeline = [
        {
            "$search": {
                "text": {
                    "query": query,
                    "path": {"wildcard": "*"},
                    "fuzzy": {}
                }
            }
        }
    ]

    collection_results = list(collection.aggregate(pipeline))

    # keep only this user's records in the output
    records = []

    for doc in collection_results:
        if doc['caller_id'] == user_id or doc['callee_id'] == user_id:
            records.append(doc)

    print(records)

    return records


def summarise(collection, call_id, user_id, target_language):
    # client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    config = dotenv_values(".env")
    client = OpenAI(api_key=config["OPENAI_API_KEY"])

    # get caption text using call_id
    caption_text = get_caption_text(collection, call_id, user_id)

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": f"The following is an extract from a call transcript. Rewrite this as a structured, clear summary in {target_language}. \
                    \n\nCall Transcript: \"\"\"\n{caption_text}\n\"\"\"\n"
            }
        ],
        model="gpt-3.5-turbo",
    )

    # gpt-3.5-turbo has a 4096 token limit -> the request will fail if exceeded
    try:
        result = chat_completion.choices[0].message.content.split(":")[1].strip()  # parse summary
    except:
        return None

    # BO - add result to mongodb -> should be done asynchronously
    update_result = collection.update_one({"call_id": call_id}, {"$set": {f"summaries.{user_id}": result}})

    if update_result.modified_count == 0:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not updated!")

    return result


def term_extraction(collection, call_id, user_id, target_language):
    combined_text = get_caption_text(collection, call_id, user_id)

    if combined_text and len(combined_text) > 50:  # term extraction is poor on short transcripts
        # Extract key terms from the concatenated caption field
        key_terms = extract_terms(combined_text, target_language, len(combined_text))

        update_result = collection.update_one({"call_id": call_id}, {"$set": {f"key_terms.{user_id}": key_terms}})

        if update_result.modified_count == 0:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not updated!")

        return key_terms
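Worth noting for the fuzzy search above: the `$search` aggregation stage is an Atlas Search feature backed by an Atlas Search index, not by the legacy `text` index that `create_index` builds (that index serves the `$text` operator used in `full_text_search`). So the pipeline only returns results when a search index exists on the collection in Atlas. A minimal sketch of the same query against an assumed search index named "default":

    # Hedged sketch: assumes an Atlas Search index named "default" covering the
    # captions fields; "maxEdits" bounds how many character edits a match may need.
    pipeline = [
        {
            "$search": {
                "index": "default",
                "text": {
                    "query": "fuzzy serch term",
                    "path": {"wildcard": "*"},
                    "fuzzy": {"maxEdits": 2},
                },
            }
        },
        {"$limit": 20},
    ]
    results = list(collection.aggregate(pipeline))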
backend/mongodb/operations/users.py
ADDED
@@ -0,0 +1,77 @@
from fastapi import Body, Request, HTTPException, status
from fastapi.encoders import jsonable_encoder
import sys
from ..models.users import User, UpdateUser
from bson import ObjectId
import re


def create_user(collection, user: User = Body(...)):
    user = jsonable_encoder(user)
    new_user = collection.insert_one(user)
    created_user = collection.find_one({"_id": new_user.inserted_id})
    print("NEW ID IS:.........", new_user.inserted_id)
    return created_user


def list_users(collection, limit: int):
    try:
        users = list(collection.find(limit=limit))
        return users
    except:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No users found!")


def find_user(collection, user_id: str):
    if (user := collection.find_one({"user_id": user_id})):
        return user
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")

def find_name_from_id(collection, user_id: str):
    # find_one user record based on user id and project for user name
    if (user_name := collection.find_one({"user_id": user_id}, {"name": 1, "_id": 0})):
        return user_name['name']  # index name field from single-field record returned
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")

def find_user_name(collection, name: str):
    # case-insensitive exact match on name
    if (user := collection.find_one({"name": re.compile('^' + re.escape(name) + '$', re.IGNORECASE)})):
        return user
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with name {name} not found!")


def find_user_email(collection, email: str):
    if (user := collection.find_one({"email": re.compile('^' + re.escape(email) + '$', re.IGNORECASE)})):
        return user
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with Email Address {email} not found!")


''' Update user record based on user object/json'''
def update_user(collection, user_id: str, user: UpdateUser):
    try:
        user = {k: v for k, v in user.model_dump().items() if v is not None}
        if len(user) >= 1:
            update_result = collection.update_one({"user_id": user_id}, {"$set": user})

            if update_result.modified_count == 0:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id: '{user_id}' not found or not updated!")

            if (existing_users := collection.find_one({"user_id": user_id})) is not None:
                return existing_users
    except:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id: '{user_id}' not found or not updated!")


def delete_user(collection, user_id: str):
    deleted_user = collection.delete_one({"user_id": user_id})

    if deleted_user.deleted_count == 1:
        return f"User with user_id {user_id} deleted successfully"

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with user_id {user_id} not found!")
backend/pcmToWav.py
ADDED
@@ -0,0 +1,34 @@
import wave
import os


basePath = os.path.expanduser("~/Desktop/")


def convert_pcm_to_wav():
    # PCM file parameters (should match the parameters used to create the PCM file)
    pcm_file = basePath + 'output.pcm'
    wav_file = 'pcmconverted.wav'
    sample_rate = 16000  # Example: 16000 Hz
    channels = 1  # Example: 2 for stereo
    sample_width = 2  # Example: 2 bytes (16 bits), change if your PCM format is different

    # Read the PCM file and write to a WAV file
    with open(pcm_file, 'rb') as pcmfile:
        pcm_data = pcmfile.read()

    with wave.open(wav_file, 'wb') as wavfile:
        wavfile.setnchannels(channels)
        wavfile.setsampwidth(sample_width)
        wavfile.setframerate(sample_rate)
        wavfile.writeframes(pcm_data)

    print(f"Converted {pcm_file} to {wav_file}")

convert_pcm_to_wav()

# def generateCaptions(filepath):

# ! This might be redundant due to seamless-streaming
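
A quick sanity check on the conversion above; this sketch just re-reads the header the script wrote:

# Verify the WAV header matches the PCM parameters used in convert_pcm_to_wav().
import wave

with wave.open('pcmconverted.wav', 'rb') as wf:
    assert wf.getframerate() == 16000
    assert wf.getnchannels() == 1
    assert wf.getsampwidth() == 2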
backend/preprocess_wav.py
ADDED
@@ -0,0 +1,65 @@
import soundfile
import io
from typing import Any, Tuple, Union, Optional
import numpy as np
import torch

def preprocess_wav(data: Any, incoming_sample_rate) -> Tuple[np.ndarray, int]:
    segment, sample_rate = soundfile.read(
        io.BytesIO(data),
        dtype="float32",
        always_2d=True,
        frames=-1,
        start=0,
        format="RAW",
        subtype="PCM_16",
        samplerate=incoming_sample_rate,
        channels=1,
    )
    return segment, sample_rate

def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """convert a waveform:
    - to a target sample rate
    - from multi-channel to mono channel
    - volume normalization

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate
    Returns:
        waveform (numpy.ndarray): converted 2D waveform (channels x length)
        sample_rate (float): target sample rate
    """
    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])
    if len(effects) > 0:
        is_np_input = isinstance(waveform, np.ndarray)
        _waveform = torch.from_numpy(waveform) if is_np_input else waveform
        converted, converted_sample_rate = ta_sox.apply_effects_tensor(
            _waveform, sample_rate, effects
        )
        if is_np_input:
            converted = converted.numpy()
        return converted, converted_sample_rate
    return waveform, sample_rate
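
A minimal usage sketch, assuming raw PCM-16 mono bytes captured at 44.1 kHz (the byte string is illustrative):

# Decode 44.1 kHz PCM-16 bytes, then resample to 16 kHz mono.
from preprocess_wav import preprocess_wav, convert_waveform

raw_bytes = b"\x00\x00" * 44100               # one second of silence (illustrative)
segment, sr = preprocess_wav(raw_bytes, incoming_sample_rate=44100)
waveform = segment.T                          # (frames, channels) -> (channels, frames)
converted, new_sr = convert_waveform(waveform, sr, to_mono=True, to_sample_rate=16000)
assert new_sr == 16000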
backend/requirements.txt
ADDED
@@ -0,0 +1,27 @@
colorlog==6.8.2
contextlib2==21.6.0
fastapi==0.110.1
g2p_en==2.1.0
matplotlib==3.7.0
numpy==1.24.2
openai==1.20.0
protobuf==5.26.1
pydantic==2.7.0
pydub==0.25.1
pymongo==4.6.2
PySoundFile==0.9.0.post1
python-dotenv==1.0.1
python-socketio==5.9.0
Requests==2.31.0
sentencepiece==0.1.99
simuleval==1.1.4
soundfile==0.12.1
spacy==3.7.4
pytextrank==3.3.0
torch==2.1.2
torchaudio==2.1.2
#transformers==4.20.1
uvicorn==0.29.0
vad==1.0.2
hf_transfer==0.1.4
huggingface_hub==0.19.4
backend/routes/__init__.py
ADDED
@@ -0,0 +1 @@
from .routing import router
backend/routes/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (235 Bytes). View file
backend/routes/__pycache__/routing.cpython-310.pyc
ADDED
Binary file (375 Bytes). View file
backend/routes/routing.py
ADDED
@@ -0,0 +1,9 @@
from fastapi import APIRouter
import sys
# sys.path.append('/Users/benolojo/DCU/CA4/ca400_FinalYearProject/2024-ca400-olojob2-majdap2/src/backend/src/')
from mongodb.endpoints import users, calls

router = APIRouter()
router.include_router(calls.router)
router.include_router(users.router)
# router.include_router(transcripts.router)
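
A sketch of how this aggregate router would be mounted on the application (the FastAPI app lives in backend/main.py; the exact wiring there is outside this hunk):

# Assumes the backend package root is on sys.path so `routes` is importable.
from fastapi import FastAPI
from routes import router

app = FastAPI()
app.include_router(router)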
backend/seamless/__init__.py
ADDED
File without changes
backend/seamless/room.py
ADDED
@@ -0,0 +1,64 @@
# import json
import uuid


class Room:
    def __init__(self, room_id) -> None:
        self.room_id = room_id
        # members is a dict from client_id to Member
        self.members = {}

        # listeners and speakers are lists of client_id's
        self.listeners = []
        self.speakers = []

    def __str__(self) -> str:
        return f"Room {self.room_id} ({len(self.members)} member{'' if len(self.members) == 1 else 's'})"

    def to_json(self):
        varsResult = vars(self)
        # Remember: result is just a shallow copy, so result.members === self.members
        # Because of that, we need to jsonify self.members without writing over result.members,
        # which we do here via dictionary unpacking (the ** operator)
        result = {
            **varsResult,
            "members": {key: value.to_json() for (key, value) in self.members.items()},
            "activeTranscoders": self.get_active_transcoders(),
        }

        return result

    def get_active_connections(self):
        return len(
            [m for m in self.members.values() if m.connection_status == "connected"]
        )

    def get_active_transcoders(self):
        return len([m for m in self.members.values() if m.transcoder is not None])

    def get_room_status_dict(self):
        return {
            "activeConnections": self.get_active_connections(),
            "activeTranscoders": self.get_active_transcoders(),
        }


class Member:
    def __init__(self, client_id, session_id, name) -> None:
        self.client_id = client_id
        self.session_id = session_id
        self.name = name
        self.connection_status = "connected"
        self.transcoder = None
        self.requested_output_type = None
        self.transcoder_dynamic_config = None

    def __str__(self) -> str:
        return f"{self.name} (id: {self.client_id[:4]}...) ({self.connection_status})"

    def to_json(self):
        self_vars = vars(self)
        return {
            **self_vars,
            "transcoder": self.transcoder is not None,
        }
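
A quick sketch of how these containers compose (identifiers are illustrative):

# Illustrative only: build a room, add a member, serialize for the client.
room = Room("room-1234")
member = Member(client_id="abcd-efgh", session_id="sess-1", name="Alice")
room.members[member.client_id] = member
print(room)               # Room room-1234 (1 member)
payload = room.to_json()  # JSON-safe dict; "transcoder" becomes a bool flag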
backend/seamless/simuleval_agent_directory.py
ADDED
@@ -0,0 +1,171 @@
# Creates a directory in which to look up available agents

import os
from typing import List, Optional
from seamless.simuleval_transcoder import SimulevalTranscoder
import json
import logging

logger = logging.getLogger("gunicorn")

# fmt: off
M4T_P0_LANGS = [
    "eng",
    "arb", "ben", "cat", "ces", "cmn", "cym", "dan",
    "deu", "est", "fin", "fra", "hin", "ind", "ita",
    "jpn", "kor", "mlt", "nld", "pes", "pol", "por",
    "ron", "rus", "slk", "spa", "swe", "swh", "tel",
    "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie",
]
# fmt: on


class NoAvailableAgentException(Exception):
    pass


class AgentWithInfo:
    def __init__(
        self,
        agent,
        name: str,
        modalities: List[str],
        target_langs: List[str],
        # Supported dynamic params are defined in StreamingTypes.ts
        dynamic_params: List[str] = [],
        description="",
        has_expressive: Optional[bool] = None,
    ):
        self.agent = agent
        self.has_expressive = has_expressive
        self.name = name
        self.description = description
        self.modalities = modalities
        self.target_langs = target_langs
        self.dynamic_params = dynamic_params

    def get_capabilities_for_json(self):
        return {
            "name": self.name,
            "description": self.description,
            "modalities": self.modalities,
            "targetLangs": self.target_langs,
            "dynamicParams": self.dynamic_params,
        }

    @classmethod
    def load_from_json(cls, config: str):
        """
        Takes in a JSON array of models to load, e.g.
        [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
         {"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]
        """
        configs = json.loads(config)
        agents = []
        for config in configs:
            agent = SimulevalTranscoder.build_agent(config["name"])
            agents.append(
                AgentWithInfo(
                    agent=agent,
                    name=config["name"],
                    modalities=config["modalities"],
                    target_langs=config["targetLangs"],
                )
            )
        return agents


class SimulevalAgentDirectory:
    # Available models. These are the directories where the models can be found, and also serve as an ID for the model.
    seamless_streaming_agent = "SeamlessStreaming"
    seamless_agent = "Seamless"

    def __init__(self):
        self.agents = []
        self.did_build_and_add_agents = False

    def add_agent(self, agent: AgentWithInfo):
        self.agents.append(agent)

    def build_agent_if_available(self, model_id, config_name=None):
        agent = None
        try:
            if config_name is not None:
                agent = SimulevalTranscoder.build_agent(
                    model_id,
                    config_name=config_name,
                )
            else:
                agent = SimulevalTranscoder.build_agent(
                    model_id,
                )
        except Exception as e:
            from fairseq2.assets.error import AssetError
            logger.warning("Failed to build agent %s: %s" % (model_id, e))
            if isinstance(e, AssetError):
                logger.warning(
                    "Please download gated assets and set `gated_model_dir` in the config"
                )
            raise e

        return agent

    def build_and_add_agents(self, models_override=None):
        if self.did_build_and_add_agents:
            return

        if models_override is not None:
            agent_infos = AgentWithInfo.load_from_json(models_override)
            for agent_info in agent_infos:
                self.add_agent(agent_info)
        else:
            s2s_agent = None
            if os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1":
                logger.info("Building expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_agent,
                    config_name="vad_s2st_sc_24khz_main.yaml",
                )
                has_expressive = True
            else:
                logger.info("Building non-expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_streaming_agent,
                    config_name="vad_s2st_sc_main.yaml",
                )
                has_expressive = False

            if s2s_agent:
                self.add_agent(
                    AgentWithInfo(
                        agent=s2s_agent,
                        name=SimulevalAgentDirectory.seamless_streaming_agent,
                        modalities=["s2t", "s2s"],
                        target_langs=M4T_P0_LANGS,
                        dynamic_params=["expressive"],
                        description="multilingual expressive model that supports S2S and S2T",
                        has_expressive=has_expressive,
                    )
                )

        if len(self.agents) == 0:
            logger.error(
                "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
            )

        self.did_build_and_add_agents = True

    def get_agent(self, name):
        for agent in self.agents:
            if agent.name == name:
                return agent
        return None

    def get_agent_or_throw(self, name):
        agent = self.get_agent(name)
        if agent is None:
            raise NoAvailableAgentException("No agent found with name=%s" % name)
        return agent

    def get_agents_capabilities_list_for_json(self):
        return [agent.get_capabilities_for_json() for agent in self.agents]
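
A sketch of the expected lifecycle: build the default agent once at startup, then look it up per client request (loading requires the actual model files under backend/models/):

agent_directory = SimulevalAgentDirectory()
agent_directory.build_and_add_agents()

capabilities = agent_directory.get_agents_capabilities_list_for_json()
agent = agent_directory.get_agent_or_throw(
    SimulevalAgentDirectory.seamless_streaming_agent
)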
backend/seamless/simuleval_transcoder.py
ADDED
@@ -0,0 +1,428 @@
from simuleval.utils.agent import build_system_from_dir
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import soundfile
import io
import asyncio
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
import threading
import math
import logging
import sys
from pathlib import Path
import time
from g2p_en import G2p
import torch
import traceback
import random
import colorlog

from .speech_and_text_output import SpeechAndTextOutput

MODEL_SAMPLE_RATE = 16_000

logger = logging.getLogger(__name__)
# logger.propagate = False
handler = colorlog.StreamHandler(stream=sys.stdout)
formatter = colorlog.ColoredFormatter(
    "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s",
    reset=True,
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    },
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.WARNING)


class OutputSegments:
    def __init__(self, segments: Union[List[Segment], Segment]):
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = [s for s in segments]

    @property
    def is_empty(self):
        return all(segment.is_empty for segment in self.segments)

    @property
    def finished(self):
        return all(segment.finished for segment in self.segments)

    def compute_length(self, g2p):
        lengths = []
        for segment in self.segments:
            if segment.data_type == "text":
                lengths.append(len([x for x in g2p(segment.content) if x != " "]))
            elif segment.data_type == "speech":
                lengths.append(len(segment.content) / MODEL_SAMPLE_RATE)
            elif isinstance(segment, EmptySegment):
                continue
            else:
                logger.warning(
                    f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'"
                )
        return max(lengths, default=0)

    @classmethod
    def join_output_buffer(
        cls, buffer: List[List[Segment]], output: SpeechAndTextOutput
    ):
        num_segments = len(buffer[0])
        for i in range(num_segments):
            segment_list = [
                buffer[j][i]
                for j in range(len(buffer))
                if buffer[j][i].data_type is not None
            ]
            if len(segment_list) == 0:
                continue
            if len(set(segment.data_type for segment in segment_list)) != 1:
                logger.warning(
                    f"Data type mismatch at {i}: {set(segment.data_type for segment in segment_list)}"
                )
                continue
            data_type = segment_list[0].data_type
            if data_type == "text":
                if output.text is not None:
                    logger.warning("Multiple text outputs, overwriting!")
                output.text = " ".join([segment.content for segment in segment_list])
            elif data_type == "speech":
                if output.speech_samples is not None:
                    logger.warning("Multiple speech outputs, overwriting!")
                speech_out = []
                for segment in segment_list:
                    speech_out += segment.content
                output.speech_samples = speech_out
                output.speech_sample_rate = segment.sample_rate
            elif isinstance(segment_list[0], EmptySegment):
                continue
            else:
                logger.warning(
                    f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text'"
                )

        return output

    def __repr__(self) -> str:
        repr_str = str(self.segments)
        return f"{self.__class__.__name__}(\n\t{repr_str}\n)"


class SimulevalTranscoder:
    def __init__(self, agent, sample_rate, debug, buffer_limit):
        self.agent = agent.agent
        self.has_expressive = agent.has_expressive
        self.input_queue = asyncio.Queue()
        self.output_queue = asyncio.Queue()
        self.states = self.agent.build_states()
        if debug:
            self.get_states_root().debug = True
        self.incoming_sample_rate = sample_rate
        self.close = False
        self.g2p = G2p()

        # buffer all outgoing translations within this amount of time
        self.output_buffer_idle_ms = 5000
        self.output_buffer_size_limit = (
            buffer_limit  # phonemes for text, seconds for speech
        )
        self.output_buffer_cur_size = 0
        self.output_buffer: List[List[Segment]] = []
        self.speech_output_sample_rate = None

        self.last_output_ts = time.time() * 1000
        self.timeout_ms = (
            30000  # close the transcoder thread after this amount of silence
        )
        self.first_input_ts = None
        self.first_output_ts = None
        self.debug = debug
        self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
        if self.debug:
            debug_folder = Path(__file__).resolve().parent.parent / "debug"
            self.test_incoming_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_incoming.wav",
                mode="w+",
                format="WAV",
                subtype="PCM_16",
                samplerate=self.incoming_sample_rate,
                channels=1,
            )
            self.get_states_root().test_input_segments_wav = soundfile.SoundFile(
                debug_folder / f"{self.debug_ts}_test_input_segments.wav",
                mode="w+",
                format="WAV",
                samplerate=MODEL_SAMPLE_RATE,
                channels=1,
            )

    def get_states_root(self) -> AgentStates:
        if isinstance(self.agent, TreeAgentPipeline):
            # self.states is a dict
            return self.states[self.agent.source_module]
        else:
            # self.states is a list
            return self.states[0]

    def reset_states(self):
        if isinstance(self.agent, TreeAgentPipeline):
            states_iter = self.states.values()
        else:
            states_iter = self.states
        for state in states_iter:
            state.reset()

    def debug_log(self, *args):
        if self.debug:
            logger.info(*args)

    @classmethod
    def build_agent(cls, model_path, config_name):
        logger.info(f"Building simuleval agent: {model_path}, {config_name}")
        agent = build_system_from_dir(
            Path(__file__).resolve().parent.parent / f"models/{model_path}",
            config_name=config_name,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent.to(device, fp16=True)
        logger.info(
            f"Successfully built simuleval agent {model_path} on device {device}"
        )

        return agent

    def process_incoming_bytes(self, incoming_bytes, dynamic_config):
        # TODO: We probably want to do some validation on dynamic_config to ensure it has what we need
        segment, sr = self._preprocess_wav(incoming_bytes)
        segment = SpeechSegment(
            content=segment,
            sample_rate=sr,
            tgt_lang=dynamic_config.get("targetLanguage"),
            config=dynamic_config,
        )
        if dynamic_config.get("expressive") is True and self.has_expressive is False:
            logger.warning(
                "Passing 'expressive' but the agent does not support expressive output!"
            )
        # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
        self.input_queue.put_nowait(segment)

    def get_input_segment(self):
        if self.input_queue.empty():
            return None
        chunk = self.input_queue.get_nowait()
        self.input_queue.task_done()
        return chunk

    def convert_waveform(
        self,
        waveform: Union[np.ndarray, torch.Tensor],
        sample_rate: int,
        normalize_volume: bool = False,
        to_mono: bool = False,
        to_sample_rate: Optional[int] = None,
    ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
        """convert a waveform:
        - to a target sample rate
        - from multi-channel to mono channel
        - volume normalization

        Args:
            waveform (numpy.ndarray or torch.Tensor): 2D original waveform
                (channels x length)
            sample_rate (int): original sample rate
            normalize_volume (bool): perform volume normalization
            to_mono (bool): convert to mono channel if having multiple channels
            to_sample_rate (Optional[int]): target sample rate
        Returns:
            waveform (numpy.ndarray): converted 2D waveform (channels x length)
            sample_rate (float): target sample rate
        """
        try:
            import torchaudio.sox_effects as ta_sox
        except ImportError:
            raise ImportError("Please install torchaudio: pip install torchaudio")

        effects = []
        if normalize_volume:
            effects.append(["gain", "-n"])
        if to_sample_rate is not None and to_sample_rate != sample_rate:
            effects.append(["rate", f"{to_sample_rate}"])
        if to_mono and waveform.shape[0] > 1:
            effects.append(["channels", "1"])
        if len(effects) > 0:
            is_np_input = isinstance(waveform, np.ndarray)
            _waveform = torch.from_numpy(waveform) if is_np_input else waveform
            converted, converted_sample_rate = ta_sox.apply_effects_tensor(
                _waveform, sample_rate, effects
            )
            if is_np_input:
                converted = converted.numpy()
            return converted, converted_sample_rate
        return waveform, sample_rate

    def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
        segment, sample_rate = soundfile.read(
            io.BytesIO(data),
            dtype="float32",
            always_2d=True,
            frames=-1,
            start=0,
            format="RAW",
            subtype="PCM_16",
            samplerate=self.incoming_sample_rate,
            channels=1,
        )
        if self.debug:
            self.test_incoming_wav.seek(0, soundfile.SEEK_END)
            self.test_incoming_wav.write(segment)

        segment = segment.T
        segment, new_sample_rate = self.convert_waveform(
            segment,
            sample_rate,
            normalize_volume=False,
            to_mono=True,
            to_sample_rate=MODEL_SAMPLE_RATE,
        )

        assert MODEL_SAMPLE_RATE == new_sample_rate
        segment = segment.squeeze(axis=0)
        return segment, new_sample_rate

    def process_pipeline_impl(self, input_segment):
        try:
            with torch.no_grad():
                output_segment = OutputSegments(
                    self.agent.pushpop(input_segment, self.states)
                )
            if (
                self.get_states_root().first_input_ts is not None
                and self.first_input_ts is None
            ):
                # TODO: this is hacky
                self.first_input_ts = self.get_states_root().first_input_ts

            if not output_segment.is_empty:
                self.output_queue.put_nowait(output_segment)

            if output_segment.finished:
                self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")

                self.reset_states()

                if self.debug:
                    # when we rebuild states, this value is reset to whatever
                    # is in the system dir config, which defaults debug=False.
                    self.get_states_root().debug = True
        except Exception as e:
            logger.error(f"Got exception while processing pipeline: {e}")
            traceback.print_exc()
        return input_segment

    def process_pipeline_loop(self):
        if self.close:
            return  # closes the thread

        self.debug_log("processing_pipeline")
        while not self.close:
            input_segment = self.get_input_segment()
            if input_segment is None:
                if self.get_states_root().is_fresh_state:  # TODO: this is hacky
                    time.sleep(0.3)
                else:
                    time.sleep(0.03)
                continue
            self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline")

    def process_pipeline_once(self):
        if self.close:
            return

        self.debug_log("processing pipeline once")
        input_segment = self.get_input_segment()
        if input_segment is None:
            return
        self.process_pipeline_impl(input_segment)
        self.debug_log("finished processing_pipeline_once")

    def get_output_segment(self):
        if self.output_queue.empty():
            return None

        output_chunk = self.output_queue.get_nowait()
        self.output_queue.task_done()
        return output_chunk

    def start(self):
        self.debug_log("starting transcoder in a thread")
        threading.Thread(target=self.process_pipeline_loop).start()

    def first_translation_time(self):
        return round((self.first_output_ts - self.first_input_ts) / 1000, 2)

    def get_buffered_output(self) -> SpeechAndTextOutput:
        now = time.time() * 1000
        self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
        while not self.output_queue.empty():
            tmp_out = self.get_output_segment()
            if tmp_out and tmp_out.compute_length(self.g2p) > 0:
                if len(self.output_buffer) == 0:
                    self.last_output_ts = now
                self._populate_output_buffer(tmp_out)
                self._increment_output_buffer_size(tmp_out)

                if tmp_out.finished:
                    self.debug_log("tmp_out.finished")
                    res = self._gather_output_buffer_data(final=True)
                    self.debug_log(f"gathered output data: {res}")
                    self.output_buffer = []
                    self.output_buffer_cur_size = 0  # reset the buffered-size counter
                    self.last_output_ts = now
                    self.first_output_ts = now
                    return res
            else:
                self.debug_log("tmp_out.compute_length is not > 0")

        if len(self.output_buffer) > 0 and (
            now - self.last_output_ts >= self.output_buffer_idle_ms
            or self.output_buffer_cur_size >= self.output_buffer_size_limit
        ):
            self.debug_log(
                "[get_buffered_output] output_buffer is not empty. getting res to return."
            )
            self.last_output_ts = now
            res = self._gather_output_buffer_data(final=False)
            self.debug_log(f"gathered output data: {res}")
            self.output_buffer = []
            self.output_buffer_cur_size = 0  # reset the buffered-size counter
            self.first_output_ts = now
            return res
        else:
            self.debug_log("[get_buffered_output] output_buffer is empty...")
            return None

    def _gather_output_buffer_data(self, final):
        output = SpeechAndTextOutput()
        output.final = final
        output = OutputSegments.join_output_buffer(self.output_buffer, output)
        return output

    def _increment_output_buffer_size(self, segment: OutputSegments):
        self.output_buffer_cur_size += segment.compute_length(self.g2p)

    def _populate_output_buffer(self, segment: OutputSegments):
        self.output_buffer.append(segment.segments)

    def _compute_phoneme_count(self, string: str) -> int:
        return len([x for x in self.g2p(string) if x != " "])
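
A condensed driver sketch; the socket layer in backend/main.py is the real caller, and `agent_info` (an AgentWithInfo) and `pcm16_bytes` here are stand-in assumptions:

# Illustrative driver: feed PCM bytes in, poll buffered translations out.
transcoder = SimulevalTranscoder(agent_info, sample_rate=16000, debug=False, buffer_limit=1)
transcoder.start()  # spawns process_pipeline_loop in a background thread

transcoder.process_incoming_bytes(pcm16_bytes, {"targetLanguage": "spa"})
out = transcoder.get_buffered_output()
if out is not None:
    print(out.text, out.final)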
backend/seamless/speech_and_text_output.py
ADDED
@@ -0,0 +1,15 @@
# Provides a container to return both speech and text output from our model at the same time


class SpeechAndTextOutput:
    def __init__(
        self,
        text: str = None,
        speech_samples: list = None,
        speech_sample_rate: float = None,
        final: bool = False,
    ):
        self.text = text
        self.speech_samples = speech_samples
        self.speech_sample_rate = speech_sample_rate
        self.final = final
backend/seamless/transcoder_helpers.py
ADDED
@@ -0,0 +1,43 @@
import logging

logger = logging.getLogger("gunicorn")


def get_transcoder_output_events(transcoder) -> list:
    speech_and_text_output = transcoder.get_buffered_output()
    if speech_and_text_output is None:
        logger.debug("No output from transcoder.get_buffered_output()")
        return []

    logger.debug(f"We DID get output from the transcoder! {speech_and_text_output}")

    lat = None

    events = []

    if speech_and_text_output.speech_samples:
        events.append(
            {
                "event": "translation_speech",
                "payload": speech_and_text_output.speech_samples,
                "sample_rate": speech_and_text_output.speech_sample_rate,
            }
        )

    if speech_and_text_output.text:
        events.append(
            {
                "event": "translation_text",
                "payload": speech_and_text_output.text,
            }
        )

    for e in events:
        e["eos"] = speech_and_text_output.final

    # if not latency_sent:
    #     lat = transcoder.first_translation_time()
    #     latency_sent = True
    #     to_send["latency"] = lat

    return events
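
A sketch of how these event dicts might be relayed over python-socketio; the event-name/room scheme is an assumption, only sio.emit itself is the library API:

# Hypothetical relay: push transcoder events to everyone in the call room.
async def relay_translations(sio, transcoder, call_id):
    for event in get_transcoder_output_events(transcoder):
        # event["event"] is "translation_speech" or "translation_text"
        await sio.emit(event["event"], event, room=call_id)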
backend/seamless_utils.py
ADDED
@@ -0,0 +1,210 @@
# base seamless imports
# ---------------------------------
import io
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import mmap
import numpy as np
import soundfile
import torchaudio
import torch
from pydub import AudioSegment
# ---------------------------------
# seamless-streaming specific imports
# ---------------------------------
import math
from simuleval.data.segments import SpeechSegment, EmptySegment
from seamless_communication.streaming.agents.seamless_streaming_s2st import (
    SeamlessStreamingS2STVADAgent,
)

from simuleval.utils.arguments import cli_argument_list
from simuleval import options


from typing import Union, List
from simuleval.data.segments import Segment, TextSegment
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
# ---------------------------------
# seamless setup
# source: https://colab.research.google.com/github/kauterry/seamless_communication/blob/main/Seamless_Tutorial.ipynb?
SAMPLE_RATE = 16000

# PM - This class is used to simulate the audio frontend in the seamless streaming pipeline;
# need to replace this with the actual audio frontend
# TODO: replacement class that takes in PCM-16 bytes and returns SpeechSegment
class AudioFrontEnd:
    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        print(self.sample_rate, "sample rate")
        assert self.sample_rate == SAMPLE_RATE
        # print(len(self.samples), self.samples[:100])
        self.samples = self.samples  # .tolist()
        self.segment_size = segment_size
        self.step = 0

    def send_segment(self):
        """
        This is the front-end logic in simuleval instance.py
        """

        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)

        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step:]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            # NOTE: self.step alone tracks the read position; truncating
            # self.samples here as well would skip audio.
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finish reading this audio
            segment = EmptySegment(
                finished=True,
            )
            self.step = 0
            self.samples = []
        return segment

        # samples = self.samples[:num_samples]
        # self.samples = self.samples[num_samples:]
        # segment = SpeechSegment(
        #     content=samples,
        #     sample_rate=self.sample_rate,
        #     finished=False,
        # )


    def add_segments(self, wav):
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))


class OutputSegments:
    def __init__(self, segments: Union[List[Segment], Segment]):
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = [s for s in segments]

    @property
    def is_empty(self):
        return all(segment.is_empty for segment in self.segments)

    @property
    def finished(self):
        return all(segment.finished for segment in self.segments)


def get_audiosegment(samples, sr):
    b = io.BytesIO()
    soundfile.write(b, samples, samplerate=sr, format="wav")
    b.seek(0)
    return AudioSegment.from_file(b)


def reset_states(system, states):
    if isinstance(system, TreeAgentPipeline):
        states_iter = states.values()
    else:
        states_iter = states
    for state in states_iter:
        state.reset()


def get_states_root(system, states) -> AgentStates:
    if isinstance(system, TreeAgentPipeline):
        # states is a dict
        return states[system.source_module]
    else:
        # states is a list
        return states[0]


def build_streaming_system(model_configs, agent_class):
    parser = options.general_parser()
    parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")

    agent_class.add_args(parser)
    args, _ = parser.parse_known_args(cli_argument_list(model_configs))
    system = agent_class.from_args(args)
    return system


def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    # NOTE: Here for visualization, we calculate delays offset from audio
    # *BEFORE* VAD segmentation.
    # In contrast for SimulEval evaluation, we assume audios are pre-segmented,
    # and Average Lagging, End Offset metrics are based on those pre-segmented audios.
    # Thus, delays here are *NOT* comparable to SimulEval per-segment delays
    delays = {"s2st": [], "s2tt": []}
    prediction_lists = {"s2st": [], "s2tt": []}
    speech_durations = []
    curr_delay = 0
    target_sample_rate = None

    while True:
        input_segment = audio_frontend.send_segment()
        input_segment.tgt_lang = tgt_lang
        curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
        if input_segment.finished:
            # a hack, we expect a real stream to end with silence
            get_states_root(system, system_states).source_finished = True
        # Translation happens here
        if isinstance(input_segment, EmptySegment):
            return None, None, None, None
        output_segments = OutputSegments(system.pushpop(input_segment, system_states))
        if not output_segments.is_empty:
            for segment in output_segments.segments:
                # NOTE: another difference from SimulEval evaluation -
                # delays are accumulated per-token
                if isinstance(segment, SpeechSegment):
                    pred_duration = 1000 * len(segment.content) / segment.sample_rate
                    speech_durations.append(pred_duration)
                    delays["s2st"].append(curr_delay)
                    prediction_lists["s2st"].append(segment.content)
                    target_sample_rate = segment.sample_rate
                elif isinstance(segment, TextSegment):
                    delays["s2tt"].append(curr_delay)
                    prediction_lists["s2tt"].append(segment.content)
                    print(curr_delay, segment.content)
        if output_segments.finished:
            reset_states(system, system_states)
        if input_segment.finished:
            # an assumption of SimulEval agents -
            # once source_finished=True, generate until output translation is finished
            break
    return delays, prediction_lists, speech_durations, target_sample_rate


def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    # calculate intervals + durations for s2st
    intervals = []

    start = prev_end = prediction_offset = delays["s2st"][0]
    target_samples = [0.0] * int(target_sample_rate * prediction_offset / 1000)

    for i, delay in enumerate(delays["s2st"]):
        start = max(prev_end, delay)

        if start > prev_end:
            # Wait source speech, add discontinuity with silence
            target_samples += [0.0] * int(
                target_sample_rate * (start - prev_end) / 1000
            )

        target_samples += prediction_lists["s2st"][i]
        duration = speech_durations[i]
        prev_end = start + duration
        intervals.append([start, duration])
    return target_samples, intervals
backend/tests/__pycache__/test_client.cpython-310-pytest-8.1.1.pyc
ADDED
Binary file (6.82 kB). View file
backend/tests/__pycache__/test_main.cpython-310-pytest-8.1.1.pyc
ADDED
Binary file (3.38 kB). View file
backend/tests/__pycache__/test_main.cpython-310.pyc
ADDED
Binary file (2.2 kB). View file
backend/tests/silence.wav
ADDED
Binary file (302 kB). View file
backend/tests/speaking.wav
ADDED
Binary file (255 kB). View file
backend/tests/test_client.py
ADDED
@@ -0,0 +1,59 @@
import os
import sys
import wave
import pytest
import torchaudio

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from Client import Client


@pytest.fixture
def mock_client():
    client = Client("test_sid", "test_client_id", original_sr=44100)
    return client

def test_client_init(mock_client):
    assert mock_client.sid == "test_sid"
    assert mock_client.client_id == "test_client_id"
    assert mock_client.call_id == None
    assert mock_client.buffer == bytearray()
    assert mock_client.output_path == "test_sid_output_audio.wav"
    assert mock_client.target_language == None
    assert mock_client.original_sr == 44100
    assert mock_client.vad.sample_rate == 16000
    assert mock_client.vad.frame_length == 25
    assert mock_client.vad.frame_shift == 20
    assert mock_client.vad.energy_threshold == 0.05
    assert mock_client.vad.pre_emphasis == 0.95

def test_client_add_bytes(mock_client):
    mock_client.add_bytes(b"test")
    assert mock_client.buffer == b"test"

def test_client_resample_and_clear(mock_client):
    location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    speaking_bytes = wave.open(location + "/speaking.wav", "rb").readframes(-1)
    mock_client.add_bytes(speaking_bytes)
    resampled_waveform = mock_client.resample_and_clear()
    torchaudio.save(location + "/testoutput.wav", resampled_waveform, 16000)
    with wave.open(location + "/testoutput.wav", "rb") as wf:
        sample_rate = wf.getframerate()
    assert mock_client.buffer == bytearray()
    assert sample_rate == 16000

def test_client_vad(mock_client):
    location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    speaking_bytes = wave.open(location + "/speaking.wav", "rb").readframes(-1)
    mock_client.add_bytes(speaking_bytes)
    resampled_waveform = mock_client.resample_and_clear()
    assert mock_client.buffer == bytearray()
    assert mock_client.vad_analyse(resampled_waveform) == True
    silent_bytes = wave.open(location + "/silence.wav", "rb").readframes(-1)
    mock_client.add_bytes(silent_bytes)
    resampled_waveform = mock_client.resample_and_clear()
    assert mock_client.buffer == bytearray()
    assert mock_client.vad_analyse(resampled_waveform) == False
backend/tests/test_main.py
ADDED
@@ -0,0 +1,90 @@
from fastapi import FastAPI
import pytest
from unittest.mock import AsyncMock, MagicMock, ANY
import socketio

import os
import sys

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from Client import Client
from main import sio, connect, disconnect, target_language, call_user, answer_call, incoming_audio, clients, rooms
from unittest.mock import patch

sio = socketio.AsyncServer(
    async_mode="asgi",
    cors_allowed_origins="*",
    # engineio_logger=logger,
)
# sio.logger.setLevel(logging.DEBUG)
socketio_app = socketio.ASGIApp(sio)

app = FastAPI()
app.mount("/", socketio_app)

@pytest.fixture(autouse=True)
def setup_clients_and_rooms():
    global clients, rooms
    clients.clear()
    rooms.clear()
    yield

@pytest.fixture
def mock_client():
    client = Client("test_sid", "test_client_id", original_sr=44100)
    return client


@pytest.mark.asyncio
async def test_connect(mock_client):
    sid = mock_client.sid
    environ = {'QUERY_STRING': 'client_id=test_client_id'}
    await connect(sid, environ)
    assert sid in clients

@pytest.mark.asyncio
async def test_disconnect(mock_client):
    sid = mock_client.sid
    clients[sid] = mock_client
    await disconnect(sid)
    assert sid not in clients

@pytest.mark.asyncio
async def test_target_language(mock_client):
    sid = mock_client.sid
    clients[sid] = mock_client
    target_lang = "fr"
    await target_language(sid, target_lang)
    assert clients[sid].target_language == "fr"

# PM - issues with socketio enter_room in these tests
# @pytest.mark.asyncio
# async def test_call_user(mock_client):
#     sid = mock_client.sid
#     clients[sid] = mock_client
#     call_id = "1234"
#     await call_user(sid, call_id)
#     assert call_id in rooms
#     assert sid in rooms[call_id]

# @pytest.mark.asyncio
# async def test_answer_call(mock_client):
#     sid = mock_client.sid
#     clients[sid] = mock_client
#     call_id = "1234"
#     await answer_call(sid, call_id)
#     assert call_id in rooms
#     assert sid in rooms[call_id]

@pytest.mark.asyncio
async def test_incoming_audio(mock_client):
    sid = mock_client.sid
    clients[sid] = mock_client
    data = b"\x01"
    call_id = "1234"
    await incoming_audio(sid, data, call_id)
    assert clients[sid].get_length() != 0
backend/utils/__pycache__/text_rank.cpython-310.pyc
ADDED
Binary file (2.03 kB). View file
backend/utils/text_rank.py
ADDED
@@ -0,0 +1,60 @@
import spacy
import pytextrank
from spacy.tokens import Span

# Define decorator for converting to singular version of words
@spacy.registry.misc("plural_scrubber")
def plural_scrubber():
    def scrubber_func(span: Span) -> str:
        return span.lemma_
    return scrubber_func


def model_selector(target_language: str):

    # Load subset of non-English models
    language_model = {
        "spa": "es_core_news_sm",
        "fra": "fr_core_news_sm",
        "pol": "pl_core_news_sm",
        "deu": "de_core_news_sm",
        "ita": "it_core_news_sm",
        "por": "pt_core_news_sm",
        "nld": "nl_core_news_sm",
        "fin": "fi_core_news_sm",
        "ron": "ro_core_news_sm",
        "rus": "ru_core_news_sm"
    }

    try:
        nlp = spacy.load(language_model[target_language])

    except KeyError:
        # Fall back to a spaCy English model
        nlp = spacy.load("en_core_web_lg")

    # Add TextRank component to pipeline with stopwords
    nlp.add_pipe("textrank", config={
        "stopwords": {token: ["NOUN"] for token in nlp.Defaults.stop_words},
        "scrubber": {"@misc": "plural_scrubber"}})

    return nlp


def extract_terms(text, target_language, length):
    nlp = model_selector(target_language)

    # Perform fact extraction on overall summary and segment summaries
    doc = nlp(text)

    if length < 100:
        # Get single most used key term
        phrases = {phrase.text for phrase in doc._.phrases[:1]}
    elif length <= 300:
        # Create unique set from top 2 ranked phrases
        phrases = {phrase.text for phrase in doc._.phrases[:2]}
    else:
        # Create unique set from top 3 ranked phrases
        phrases = {phrase.text for phrase in doc._.phrases[:3]}

    return list(phrases)
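
A sketch of the intended call; the transcript is illustrative, and `length` is the character count the function uses to decide how many ranked phrases to keep:

# Illustrative: pull key terms from an English transcript.
transcript = "The weather in Dublin has been unusually warm this spring."
terms = extract_terms(transcript, target_language="eng", length=len(transcript))
print(terms)  # e.g. ['weather']  (ranking depends on the loaded model)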