fastapi_ai_endpoints / tasks /translation.py
jxtan's picture
Added Translation Endpoint
b805057
raw
history blame
4.25 kB
from fastapi import APIRouter
from pydantic import BaseModel
from typing import Optional
from config import TEST_MODE, device, dtype, log
from fairseq2.data.text.text_tokenizer import TextTokenEncoder
from seamless_communication.inference import Translator
import spacy
import re
from datetime import datetime
# FastAPI router that carries this module's translation endpoint (/t2tt);
# presumably included into the main application elsewhere — confirm.
router = APIRouter()
# Request body for POST /t2tt.
class TranslateInput(BaseModel):
    # Texts to translate; the backend produces one translation per element.
    inputs: list[str]
    # Model id used to look up the backend in t2tt_mapping
    # (e.g. 'facebook/seamless-m4t-v2-large').
    model: str
    # Source language code (e.g. 'cmn').
    src_lang: str
    # Target language code (e.g. 'eng').
    dst_lang: str
# Response body for POST /t2tt. On success `translations` holds the results;
# on failure `error` carries a message. Both may be None (the backend's
# TEST_MODE path returns neither).
class TranslateOutput(BaseModel):
    # Source language code, echoed from the request.
    src_lang: str
    # Target language code, echoed from the request.
    dst_lang: str
    # Translated texts in input order; None when translation did not complete.
    translations: Optional[list[str]] = None
    # Human-readable error message; None when the request succeeded.
    error: Optional[str] = None
@router.post('/t2tt')
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    """Translate a batch of texts with the model named in the request.

    Looks up the backend function for ``inputs.model`` in ``t2tt_mapping``,
    runs it, logs the request and its result, and records the time the model
    was last used in ``loaded_models_last_updated``.

    Errors are reported in-band: this endpoint always returns a
    ``TranslateOutput`` and puts a message in ``error`` instead of raising.

    Args:
        inputs: Request body (texts, model id, source and target language).

    Returns:
        The backend's result, or an error-only ``TranslateOutput`` when the
        model is unknown or the backend raised.
    """
    start_time = datetime.now()
    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=f'No translation model found for {inputs.model}'
        )
    try:
        # Forward only the arguments the backend accepts. Splatting
        # ``inputs.dict()`` also passed ``model=...``, which the backend's
        # signature does not take, so every call raised TypeError.
        translations = fn(
            inputs=inputs.inputs,
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
        )
        log({
            "task": "t2tt",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": translations,
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })
        # Timestamp of last use for this model id.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**translations)
    except Exception as e:
        # Top-level request boundary: surface the failure to the client
        # rather than returning a 500.
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e)
        )
# spaCy pipelines used for sentence splitting before translation: the Chinese
# pipeline for 'cmn' input and the multilingual sentence segmenter for every
# other language. Both are loaded once, at import time, in this order.
cmn_nlp, xx_nlp = (spacy.load(pipeline) for pipeline in ("zh_core_web_sm", "xx_sent_ud_sm"))

# Unknown-token markers that may appear in model output ("<unk>" or the "⁇"
# glyph), each with at most one preceding whitespace character, so they can
# be stripped from translated text.
unk_re = re.compile(r"\s?<unk>|\s?⁇")
def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    """Translate each text in ``inputs`` with SeamlessM4T v2 (large).

    Each text is split into paragraphs on blank lines; newlines inside a
    paragraph become spaces. Each paragraph is split into sentences with
    spaCy, and the sentences of one text are translated as a single batch.
    The translated sentences are joined with single spaces, so paragraph
    breaks are not preserved in the output.

    Args:
        inputs: Texts to translate. One output string is produced per text;
            an empty or whitespace-only text yields ``""``.
        src_lang: Source language code for the tokenizer and the model.
            ``'cmn'`` selects the Chinese sentence splitter; any other value
            uses the multilingual one.
        dst_lang: Target language code passed to the model.

    Returns:
        A dict with keys:
            ``src_lang`` and ``dst_lang``: echoed from the arguments.
            ``translations``: a list of strings, or ``None`` in TEST_MODE or
            when translation failed.
            ``error``: ``None`` on success, otherwise a message containing
            the exception text.
        Exceptions raised while loading the model or creating the token
        encoder are not caught here; they propagate to the caller.
    """
    if TEST_MODE:
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }

    # Load the model on first use and keep it in the module-level cache.
    if 'facebook/seamless-m4t-v2-large' in loaded_models:
        translator = loaded_models['facebook/seamless-m4t-v2-large']
    else:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models['facebook/seamless-m4t-v2-large'] = translator

    def sent_tokenize(text, lang) -> list[str]:
        """Split ``text`` into sentence strings with the spaCy pipeline for ``lang``."""
        nlp = cmn_nlp if lang == 'cmn' else xx_nlp
        return [str(t) for t in nlp(text).sents]

    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
        """Translate one text sentence by sentence and join the results with spaces."""
        # Split into paragraphs, flatten newlines inside each paragraph, and
        # split each paragraph into sentences. Whitespace-only paragraphs and
        # sentences are skipped so that no blank items are sent to the model.
        sentences: list[str] = []
        for paragraph in text.split('\n\n'):
            if not paragraph.strip():
                continue
            for sentence in sent_tokenize(paragraph.replace("\n", " "), src_lang):
                if sentence.strip():
                    sentences.append(sentence)

        # Nothing to translate: return an empty string instead of handing an
        # empty batch to collate()/predict(), which presumably fails and would
        # take down the whole request.
        if not sentences:
            return ""

        input_tokens = translator.collate([token_encoder(s) for s in sentences])
        # Index 0 of predict()'s result is used as the text output
        # (presumably a (text, speech) pair for T2TT — confirm).
        text_output = translator.predict(
            input=input_tokens,
            task_str="T2TT",
            src_lang=src_lang,
            tgt_lang=dst_lang,
        )[0]
        translations = [unk_re.sub("", str(t)) for t in text_output]
        return " ".join(translations)

    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )

    translations = None
    error = None
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
    except Exception as e:
        print(f"Error translating text: {e}")
        # Report the cause to the client instead of a bare generic message.
        error = f"Failed to translate text: {e}"

    # Decide success from whether an exception occurred, not from the
    # truthiness of `translations`: an empty `inputs` list legitimately
    # produces [] and is not an error.
    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        "error": error,
    }
# Cache of instantiated models keyed by model id, so each model is built once
# and reused across requests (populated lazily by seamless_t2tt).
loaded_models = dict()

# Time each model id last served a request (written by the /t2tt handler).
# NOTE(review): the original comment here was truncated ("Polling every X
# minutes to"). Presumably a poller elsewhere reads these timestamps to
# unload idle models from `loaded_models` — confirm.
loaded_models_last_updated = dict()

# Dispatch table: the request's `model` field -> function that serves it.
t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}