from fastapi import APIRouter
from pydantic import BaseModel
from typing import Optional
from config import TEST_MODE, device, dtype, log
from fairseq2.data.text.text_tokenizer import TextTokenEncoder
from seamless_communication.inference import Translator
import spacy
import re
from datetime import datetime

router = APIRouter()


class TranslateInput(BaseModel):
    inputs: list[str]
    model: str
    src_lang: str
    dst_lang: str


class TranslateOutput(BaseModel):
    src_lang: str
    dst_lang: str
    translations: Optional[list[str]] = None
    error: Optional[str] = None


@router.post('/t2tt')
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    start_time = datetime.now()
    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=f'No translation model found for {inputs.model}'
        )
    try:
        # Drop 'model' from the kwargs: the handler functions only take
        # inputs, src_lang and dst_lang.
        result = fn(**inputs.dict(exclude={'model'}))
        log({
            "task": "t2tt",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": result,
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**result)
    except Exception as e:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e)
        )


cmn_nlp = spacy.load("zh_core_web_sm")
xx_nlp = spacy.load("xx_sent_ud_sm")
# Strips the tokenizer's unknown-token marker (rendered as "⁇") together with
# any whitespace preceding it.
unk_re = re.compile(r"\s?⁇")


def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    if TEST_MODE:
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }

    # Load the model once and reuse it across requests.
    if 'facebook/seamless-m4t-v2-large' in loaded_models:
        translator = loaded_models['facebook/seamless-m4t-v2-large']
    else:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models['facebook/seamless-m4t-v2-large'] = translator

    def sent_tokenize(text, lang) -> list[str]:
        # Mandarin needs its own sentence segmenter; every other language
        # goes through the multi-language model.
        if lang == 'cmn':
            return [str(t) for t in cmn_nlp(text).sents]
        return [str(t) for t in xx_nlp(text).sents]

    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str,
                               src_lang: str, dst_lang: str) -> str:
        # Split the text into paragraphs, replace intra-paragraph newlines
        # with spaces, then segment each paragraph into sentences.
        lines = [sent_tokenize(line.replace("\n", " "), src_lang)
                 for line in text.split('\n\n') if line]
        lines = [item for sublist in lines for item in sublist if item]

        # Tokenize and translate sentence by sentence, then rejoin.
        input_tokens = translator.collate([token_encoder(line) for line in lines])
        translations = [
            unk_re.sub("", str(t))
            for t in translator.predict(
                input=input_tokens,
                task_str="T2TT",
                src_lang=src_lang,
                tgt_lang=dst_lang,
            )[0]
        ]
        return " ".join(translations)

    translations = None
    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang)
                        for text in inputs]
    except Exception as e:
        print(f"Error translating text: {e}")

    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        "error": None if translations else "Failed to translate text"
    }


# Model cache. loaded_models_last_updated records when each model was last
# used, so a background job polling every X minutes can unload idle models.
loaded_models = {}
loaded_models_last_updated = {}

t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}
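
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the service yet: one way to act on
# loaded_models_last_updated is a daemon thread that wakes every X minutes and
# unloads models that have been idle past a cutoff. POLL_INTERVAL and
# UNLOAD_AFTER are hypothetical names and values, not defined in config.
import threading
import time
from datetime import timedelta

POLL_INTERVAL = 5 * 60                 # assumed: sweep every five minutes
UNLOAD_AFTER = timedelta(minutes=30)   # assumed: idle cutoff before unloading


def _evict_idle_models() -> None:
    # Drop every cached model whose last recorded use is older than the cutoff.
    now = datetime.now()
    for name, last_used in list(loaded_models_last_updated.items()):
        if now - last_used > UNLOAD_AFTER:
            loaded_models.pop(name, None)
            loaded_models_last_updated.pop(name, None)


def _eviction_loop() -> None:
    while True:
        time.sleep(POLL_INTERVAL)
        _evict_idle_models()


# Enable by starting the loop at import time:
# threading.Thread(target=_eviction_loop, daemon=True).start()

# Example request body for POST /t2tt (illustrative values; field names come
# from TranslateInput above):
# {
#     "inputs": ["Bonjour tout le monde."],
#     "model": "facebook/seamless-m4t-v2-large",
#     "src_lang": "fra",
#     "dst_lang": "eng"
# }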