"""Text-to-text translation endpoints backed by SeamlessM4T v2."""
from fastapi import APIRouter | |
from pydantic import BaseModel | |
from typing import Optional | |
from config import TEST_MODE, device, dtype, log | |
from fairseq2.data.text.text_tokenizer import TextTokenEncoder | |
from seamless_communication.inference import Translator | |
import spacy | |
import re | |
from datetime import datetime | |
# Router to be mounted by the main FastAPI application.
router = APIRouter()
class TranslateInput(BaseModel):
    """Request payload for a text-to-text translation call."""
    inputs: list[str]  # texts to translate; blank lines inside a text separate paragraphs
    model: str         # dispatch key into t2tt_mapping, e.g. 'facebook/seamless-m4t-v2-large'
    src_lang: str      # source language code (e.g. 'cmn', 'eng')
    dst_lang: str      # target language code
class TranslateOutput(BaseModel):
    """Response payload: translations on success, an error message on failure."""
    src_lang: str
    dst_lang: str
    translations: Optional[list[str]] = None  # one translated string per input text; None on failure
    error: Optional[str] = None               # human-readable error; None on success
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    """Dispatch a text-to-text translation request to the model named in *inputs*.

    Looks up the handler in ``t2tt_mapping``, runs it, logs the call, and
    records the model's last-use timestamp. Any exception is caught and
    returned in the ``error`` field rather than raised.
    """
    start_time = datetime.now()
    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            # Fixed copy-pasted wording: this is the translation endpoint,
            # not sentence embeddings.
            error=f'No translation model found for {inputs.model}'
        )
    try:
        # `model` is only the dispatch key; the handlers (see seamless_t2tt)
        # accept (inputs, src_lang, dst_lang) and would raise TypeError on an
        # unexpected `model` keyword, so it must be excluded here.
        translations = fn(**inputs.dict(exclude={'model'}))
        log({
            "task": "t2tt",  # was mislabeled "sentence_embeddings" (copy-paste)
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": translations,
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })
        # Track last use so idle models can be unloaded later.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**translations)
    except Exception as e:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e)
        )
# Sentence-segmentation pipelines: Chinese gets its dedicated model; the
# multilingual "xx" sentencizer covers every other language.
cmn_nlp = spacy.load("zh_core_web_sm")
xx_nlp = spacy.load("xx_sent_ud_sm")
# Strips "<unk>" tokens (and the ⁇ glyph some tokenizers print for them),
# together with an optional preceding space, from model output.
unk_re = re.compile(r"\s?<unk>|\s?⁇")
def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    """Translate each text in *inputs* from src_lang to dst_lang with SeamlessM4T v2.

    Each text is split into paragraphs (on blank lines) and then into
    sentences, all sentences of a text are translated as one batch, and the
    results are re-joined with spaces. Returns a dict matching
    TranslateOutput's fields.
    """
    if TEST_MODE:
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }
    # Reuse the cached translator when available; otherwise load and cache it.
    if 'facebook/seamless-m4t-v2-large' in loaded_models:
        translator = loaded_models['facebook/seamless-m4t-v2-large']
    else:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models['facebook/seamless-m4t-v2-large'] = translator

    def sent_tokenize(text, lang) -> list[str]:
        # Chinese needs its dedicated pipeline; everything else uses the
        # multilingual sentencizer.
        if lang == 'cmn':
            return [str(t) for t in cmn_nlp(text).sents]
        return [str(t) for t in xx_nlp(text).sents]

    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
        # Split into paragraphs on blank lines; single newlines within a
        # paragraph become spaces before sentence segmentation.
        lines = [sent_tokenize(line.replace("\n", " "), src_lang) for line in text.split('\n\n') if line]
        lines = [item for sublist in lines for item in sublist if item]
        # Encode every sentence and translate the whole text as one batch.
        input_tokens = translator.collate([token_encoder(line) for line in lines])
        translations = [
            unk_re.sub("", str(t))
            for t in translator.predict(
                input=input_tokens,
                task_str="T2TT",
                src_lang=src_lang,
                tgt_lang=dst_lang,
            )[0]
        ]
        return " ".join(translations)

    translations = None
    error = None
    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
    except Exception as e:
        # Previously the exception text was only printed and a generic error
        # returned; surface it to the caller. Also, an empty `inputs` list no
        # longer misreports success ([] is falsy) as a failure.
        print(f"Error translating text: {e}")
        error = f"Failed to translate text: {e}"
    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        "error": error
    }
# Caches of loaded model instances and their last-use timestamps, so idle
# models can be unloaded.
# NOTE(review): the original comment here ("Polling every X minutes to") was
# truncated — the periodic unload loop presumably lives elsewhere; confirm.
loaded_models = {}
loaded_models_last_updated = {}
# Dispatch table consulted by t2tt(): model identifier -> handler function.
t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}