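# NOTE: "Spaces: / Paused / Paused" was page-header text captured along with this
# file (the hosting page showed the status "Paused"); it is kept only as a comment
# so the module parses.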
import inspect
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from datetime import date
from operator import itemgetter
from typing import Any, Optional, Tuple, Dict, TypedDict
from typing import List
from urllib import parse
from uuid import uuid4

import numpy as np
import requests
import socketio
import torch
import uvicorn
from dotenv import dotenv_values
from fastapi import FastAPI
from fastapi import APIRouter, Body, Request, status
from fastapi.logger import logger as fastapi_logger
from fastapi.middleware.cors import CORSMiddleware
from pymongo import MongoClient
from transformers import AutoProcessor, SeamlessM4Tv2Model
# from seamless_communication.inference import Translator

# Local imports keep their original relative order: the star imports may shadow names.
from routes import router as api_router
from mongodb.operations.calls import *
from mongodb.operations.users import *
from mongodb.models.calls import UserCall, UpdateCall
# from mongodb.endpoints.calls import *
from Client import Client
# Configure logging: uvicorn's access log and FastAPI's logger reuse the handler
# list of the gunicorn error logger, so all of them write to the same sinks.
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
gunicorn_error_logger.propagate = True
gunicorn_logger.propagate = True
uvicorn_access_logger.propagate = True
uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
fastapi_logger.handlers = gunicorn_error_logger.handlers

# sio is the main socket.io entrypoint (ASGI mode, any origin allowed).
sio = socketio.AsyncServer(
    async_mode="asgi",
    cors_allowed_origins="*",
    logger=gunicorn_logger,
    engineio_logger=gunicorn_logger,
)
# sio.logger.setLevel(logging.DEBUG)
socketio_app = socketio.ASGIApp(sio)
# app.mount("/", socketio_app)

# MongoDB connection string: read it from the .env file first, then fall back to
# the process environment. The fallback helps when the host injects secrets as
# environment variables and ships no .env file.
config = dotenv_values(".env")
uri = config.get("MONGODB_URI") or os.environ.get("MONGODB_URI")
if not uri:
    # Same exception type as the former config['MONGODB_URI'] lookup, but with a
    # message that names both places that were checked.
    raise KeyError("MONGODB_URI is not set in .env or in the environment")
# MongoDB Connection Lifespan Events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the MongoDB client on startup and close it on shutdown.

    Attaches ``app.mongodb_client`` and ``app.database`` (the ``IT-Cluster1``
    database) to *app*. The ``get_collection_*`` helpers read the latter. A
    failed ``ping`` is logged but does not abort startup.

    The ``asynccontextmanager`` decorator gives Starlette the context-manager
    form it expects. The ``finally`` makes the client close even when the
    application exits with an error.
    """
    # startup logic
    app.mongodb_client = MongoClient(uri)
    app.database = app.mongodb_client['IT-Cluster1']  # connect to interpretalk primary db
    try:
        app.mongodb_client.admin.command('ping')
        print("MongoDB Connection Established...")
    except Exception as e:
        # Best effort: keep serving, but make the failure visible in the logs.
        gunicorn_logger.error(f"MongoDB ping failed: {e}")
    try:
        yield
    finally:
        # shutdown logic
        print("Closing MongoDB Connection...")
        app.mongodb_client.close()
# FastAPI application; the MongoDB client is opened and closed by `lifespan`.
# NOTE(review): `logger` is not a FastAPI constructor parameter. FastAPI stores
# unknown keyword arguments in `app.extra`, so this does not configure its logging.
app = FastAPI(
    lifespan=lifespan,
    logger=gunicorn_logger,
)

# CORS: any origin, method and header is accepted, with credentials allowed.
# NOTE(review): this was meant for the node front-end's port; with credentials
# enabled, consider listing that origin explicitly instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Routers for the user, call and transcript operations.
app.include_router(api_router)
# Module-level settings (DEBUG and the escape-hatch name are not referenced
# elsewhere in this file).
DEBUG = True
ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"
# Sampling rate handed to the processor for speech input in incoming_audio.
TARGET_SAMPLING_RATE = 16000
# A client's audio buffer is processed once it holds at least this many bytes.
MAX_BYTES_BUFFER = 960000

# Startup banner: two blank lines, then the title line.
print("\n")
print(f"{'=' * 18} Interpretalk is starting... {'=' * 18}")

###############################################
# Configure socketio server
###############################################
# TODO PM - change this to the actual path
# Seamless remnant: a static-file map for a client build. Nothing else in this
# file references it.
CLIENT_BUILD_PATH = "../streaming-react-app/dist/"
static_files = {
    "/": CLIENT_BUILD_PATH,
    "/assets/seamless-db6a2555.svg": {
        "filename": f"{CLIENT_BUILD_PATH}assets/seamless-db6a2555.svg",
        "content_type": "image/svg+xml",
    },
}
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): device selection here is automatic; an earlier note about
# hardcoding it for low VRAM no longer matches this code.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# SeamlessM4T v2 processor and model, shared by every connection for speech
# recognition and text translation.
_SEAMLESS_CHECKPOINT = "facebook/seamless-m4t-v2-large"
processor = AutoProcessor.from_pretrained(_SEAMLESS_CHECKPOINT)
model = SeamlessM4Tv2Model.from_pretrained(_SEAMLESS_CHECKPOINT).to(device)

# Seamless remnants: not referenced elsewhere in this file.
bytes_data = bytearray()
model_name = "seamlessM4T_v2_large"
if model_name == "seamlessM4T_v2_large":
    vocoder_name = "vocoder_v2"
else:
    vocoder_name = "vocoder_36langs"

# In-memory connection state for this process.
clients = {}  # socket id (sid) -> Client
rooms = {}  # call_id -> list of sids that joined that call (at most two)
def get_collection_users():
    """Return the ``user_records`` collection of the app's MongoDB database."""
    database = app.database
    return database["user_records"]
def get_collection_calls():
    """Return the ``call_records`` collection of the app's MongoDB database."""
    database = app.database
    return database["call_records"]
def test():
    """Return the static welcome payload.

    NOTE(review): no route decorator is visible for this function, so it is not
    registered with ``app`` as written. Presumably it was meant to be a GET
    route; confirm.
    """
    greeting = "Welcome to InterpreTalk!"
    return dict(message=greeting)
async def send_translated_text(client_id, username, original_text, translated_text, room_id):
    """Broadcast one caption to every socket in *room_id*.

    Emits the ``translated_text`` socket.io event. The payload carries the
    author's id and display name, the recognised (original) text, its
    translation, and the server-side timestamp.

    Args:
        client_id: Stable id of the speaking user.
        username: Display name of the speaking user.
        original_text: Text recognised from the speaker's audio.
        translated_text: *original_text* in the listener's language.
        room_id: socket.io room (the call id) to broadcast to.
    """
    data = {
        "author_id": str(client_id),
        # Display name, as also stored by send_captions (this previously
        # repeated the client id, leaving `username` unused).
        "author_username": str(username),
        "original_text": str(original_text),
        "translated_text": str(translated_text),
        "timestamp": str(datetime.now()),
    }
    gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
    await sio.emit("translated_text", data, room=room_id)
    gunicorn_logger.info("SUCCESSFULLY SENT TRANSLATED TEXT TO FRONTEND")
async def connect(sid, environ):
    """Register a freshly connected socket.io client under its socket id.

    The connecting client passes ``client_id`` in the query string. It is the
    user's stable id (the same for every connection of that user), while *sid*
    identifies this particular socket. The display name comes from
    ``find_name_from_id`` on the users collection, and a ``Client`` is stored in
    ``clients[sid]``.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered as the connect event.
    """
    print(f"📥 [event: connected] sid={sid}")
    client_id = dict(parse.parse_qsl(environ["QUERY_STRING"])).get("client_id")
    gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")
    display_name = find_name_from_id(get_collection_users(), client_id)
    new_client = Client(sid, client_id, display_name)
    clients[sid] = new_client
    print(new_client.username)
    gunicorn_logger.warning(f"Client connected: {sid}")
    gunicorn_logger.warning(clients)
async def disconnect(sid):
    """Clean up after socket *sid* disconnects, then post-process its call.

    Removes the client from ``clients`` and from its call's entry in ``rooms``;
    an emptied room is deleted. Then it calls ``term_extraction`` and
    ``summarise`` on the calls collection for the client's call, user id and
    target language. Failures in that post-processing are logged with a
    traceback and not re-raised.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered as the disconnect event.
    """
    gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
    client = clients.pop(sid, None)
    if client is None:
        # Nothing was registered for this sid (e.g. connect failed before storing it).
        gunicorn_logger.warning(f"📤 [event: disconnected] unknown sid={sid}")
        return

    call_id = client.call_id
    user_id = client.client_id
    target_lang = client.target_language

    # Free the sid's slot in its call room. A stale entry would make the room
    # look full and could be picked as the "other participant" by incoming_audio.
    participants = rooms.get(call_id)
    if participants is not None:
        if sid in participants:
            participants.remove(sid)
        if not participants:
            del rooms[call_id]

    # Perform Key Term Extraction and summarisation
    try:
        term_extraction(get_collection_calls(), call_id, user_id, target_lang)
        summarise(get_collection_calls(), call_id, user_id, target_lang)
    except Exception:
        # Not a bare except: task cancellation must still propagate.
        gunicorn_logger.exception(
            f"📤 [event: term_extraction/summarisation request error] sid={sid}, call={call_id}"
        )
async def target_language(sid, target_lang):
    """Store the language chosen by client *sid*.

    incoming_audio uses this value twice. It is the language the client's own
    speech is recognised in, and the language the other participant's speech
    is translated into for this client.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered.
    """
    gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
    client = clients[sid]
    client.target_language = target_lang
async def call_user(sid, call_id):
    """Caller side of a call: join room *call_id* and create its DB call record.

    A room holds at most two sids. If the room is already full, or *sid* is
    already in it, the request is logged and ignored. In that case the
    client's ``call_id`` is left untouched and no call record is created.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered.
    """
    # Look the client up first so an unknown sid cannot leave a stale room entry.
    client = clients[sid]
    gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")
    participants = rooms.setdefault(call_id, [])
    if sid in participants or len(participants) >= 2:
        gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)
        return

    participants.append(sid)
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        # In recent python-socketio releases AsyncServer.enter_room is a coroutine,
        # and the socket only joins the room once it is awaited. Older releases
        # return None.
        await joined
    client.call_id = call_id

    client_id = client.client_id
    gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
    # Create the call record with caller and call id; the callee and later fields are filled in afterwards.
    request_data = {
        "call_id": str(call_id),
        "caller_id": str(client_id),
        "creation_date": str(datetime.now()),
    }
    response = create_calls(get_collection_calls(), request_data)
    print(response)  # BO - print created db call record
async def audio_config(sid, sample_rate):
    """Store the sample rate of the audio that client *sid* will stream.

    The value is kept on the client as ``original_sr``. Presumably
    ``Client.resample_and_clear`` reads it when resampling to
    TARGET_SAMPLING_RATE; confirm.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered.
    """
    client = clients[sid]
    client.original_sr = sample_rate
async def answer_call(sid, call_id):
    """Callee side of a call: join room *call_id* and record this user as callee.

    A room holds at most two sids. If the room is already full, or *sid* is
    already in it, the request is logged and ignored. That way a third party
    cannot overwrite the callee on an existing call record.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered.
    """
    # Look the client up first so an unknown sid cannot leave a stale room entry.
    client = clients[sid]
    gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")
    participants = rooms.setdefault(call_id, [])
    if sid in participants or len(participants) >= 2:
        gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)
        return

    participants.append(sid)
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        # In recent python-socketio releases AsyncServer.enter_room is a coroutine,
        # and the socket only joins the room once it is awaited. Older releases
        # return None.
        await joined
    client.call_id = call_id

    client_id = client.client_id
    gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Callee with ID: {client_id} for call: {call_id}")
    # Set callee_id on the existing call record identified by call_id.
    request_data = {
        "callee_id": client_id
    }
    response = update_calls(get_collection_calls(), call_id, request_data)
    print(response)  # BO - print updated db call record
def _transcribe(audio, src_lang):
    """Recognise speech in *audio* and return it as text in *src_lang*."""
    # Assumes *audio* is already at TARGET_SAMPLING_RATE (see resample_and_clear).
    inputs = processor(audios=audio, src_lang=src_lang, return_tensors="pt",
                       sampling_rate=TARGET_SAMPLING_RATE).to(device)
    # Asking for text in the speaker's own language makes speech-to-text a transcription.
    tokens = model.generate(**inputs, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
    return processor.decode(tokens, skip_special_tokens=True)


def _translate_text(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang* with the shared model."""
    inputs = processor(text=text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt").to(device)
    tokens = model.generate(**inputs, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
    return processor.decode(tokens, skip_special_tokens=True)


async def incoming_audio(sid, data, call_id):
    """Buffer streamed audio from *sid* and caption it once enough has arrived.

    Each chunk is appended to the client's buffer. Once the buffer holds at
    least MAX_BYTES_BUFFER bytes, it is handed to ``Client.resample_and_clear``.
    If ``Client.vad_analyse`` reports speech, the following steps run:

    - the audio is transcribed in the speaker's language;
    - the text is translated into the other participant's language, only if
      the two differ;
    - the result is broadcast to the room and persisted as a caption.

    A chunk is dropped when no other connected participant is in the room.
    Errors are logged with a traceback and not re-raised.

    NOTE(review): no ``@sio`` registration decorator is visible for this
    handler; confirm it is actually registered.
    NOTE(review): model inference runs synchronously on the event-loop thread
    and blocks other socket handlers meanwhile. Consider offloading it, e.g.
    with ``loop.run_in_executor``.
    """
    try:
        speaker = clients[sid]
        speaker.add_bytes(data)
        if speaker.get_length() < MAX_BYTES_BUFFER:
            return

        gunicorn_logger.info('Buffer full, now outputting...')
        resampled_audio = speaker.resample_and_clear()
        if not speaker.vad_analyse(resampled_audio):
            return

        gunicorn_logger.info('Speech detected, now processing audio.....')
        # The speaker's own target_language is the language they speak in.
        src_lang = speaker.target_language
        peer_sid = next((other for other in rooms.get(call_id, []) if other != sid), None)
        peer = clients.get(peer_sid)
        if peer is None:
            # Nobody else (still connected) in the room, so there is no listener
            # language to translate into.
            gunicorn_logger.info(f"No peer in room {call_id} for {sid}; dropping audio chunk")
            return
        tgt_lang = peer.target_language

        asr_text = _transcribe(resampled_audio, src_lang)
        print(f"ASR TEXT = {asr_text}")
        if src_lang != tgt_lang:
            translated_text = _translate_text(asr_text, src_lang, tgt_lang)
            print(f"TRANSLATED TEXT = {translated_text}")
        else:
            # PM - both users have same language selected, no need to translate
            translated_text = asr_text

        await send_translated_text(speaker.client_id, speaker.username, asr_text, translated_text, call_id)
        # BO -> persist the caption on the call record identified by call_id
        await send_captions(speaker.client_id, speaker.username, asr_text, translated_text, call_id)
    except Exception as e:
        # logger.exception records the traceback. The previous e.with_traceback()
        # call required an argument, so it raised TypeError inside this handler.
        gunicorn_logger.exception(f"Error in incoming_audio: {e}")
async def send_captions(client_id, username, original_text, translated_text, call_id):
    """Persist one caption entry for call *call_id* via ``update_captions``.

    The entry has the same keys as the ``translated_text`` event payload: the
    author's id and display name, both texts, and the current timestamp. It is
    passed to ``update_captions`` together with the calls and users collections.

    Returns:
        Whatever ``update_captions`` returns.
    """
    print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
    caption = dict(
        author_id=str(client_id),
        author_username=str(username),
        original_text=str(original_text),
        translated_text=str(translated_text),
        timestamp=str(datetime.now()),
    )
    return update_captions(get_collection_calls(), get_collection_users(), call_id, caption)
# Serve the socket.io ASGI app for every path not matched by the FastAPI routes.
app.mount("/", socketio_app)

if __name__ == "__main__":
    # Local run: raise FastAPI logging to DEBUG *before* serving. Previously this
    # ran only after uvicorn.run() returned, i.e. once the server had stopped.
    fastapi_logger.setLevel(logging.DEBUG)
    # Pass the app object rather than the "main:app" import string. The string
    # makes uvicorn import this file again as a second module, which re-runs all
    # module-level setup, including loading the model. The string form is only
    # needed for reload/workers, which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
else:
    # Running in Docker Container (imported by the ASGI server): inherit gunicorn's level.
    fastapi_logger.setLevel(gunicorn_logger.level)