import argparse
import base64
import functools
import io
import json
import logging
import os
import random
import string
import sys
import threading
import time
from io import BytesIO
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from googletrans import Translator
from gtts import gTTS
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

from messagers.message_composer import MessageComposer
from networks.message_streamer import MessageStreamer
from utils.logger import logger

# JSON file listing the supported languages; each entry is read for its "code" key.
# The path is relative to the process working directory.
LANG_FILE = "apis/lang_name.json"

# Checkpoint and on-disk cache location used by the "/translate/ai" endpoint.
M2M100_MODEL_NAME = "facebook/m2m100_1.2B"
M2M100_CACHE_DIR = "models/"

# Serialises model loading and the per-request mutation of `tokenizer.src_lang`
# on the tokenizer instance that is shared between requests.
_M2M100_LOCK = threading.Lock()


def _load_languages():
    """Read and parse the language list, closing the file afterwards."""
    with open(LANG_FILE, "r", encoding="utf-8") as lang_file:
        return json.load(lang_file)


@functools.lru_cache(maxsize=1)
def _load_m2m100(model_name: str, cache_dir: str, device_name: str):
    """Load the M2M100 tokenizer and model once and keep them in memory.

    Args:
        model_name: Hugging Face model identifier.
        cache_dir: Directory used for the downloaded weights.
        device_name: Torch device string the model is moved to.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        model_name, cache_dir=cache_dir
    ).to(device_name)
    model.eval()
    return tokenizer, model


class ChatAPIApp:
    """FastAPI application exposing translation, detection and TTS routes."""

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_langs(self):
        """Return the list of supported languages."""
        self.available_models = _load_languages()
        return self.available_models

    class TranslateCompletionsPostItem(BaseModel):
        """Request body for the translation endpoints."""

        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def translate_completions(self, item: TranslateCompletionsPostItem):
        """Translate text with Google Translate.

        The target language falls back to English when it is not in the
        supported language list. The source language is always auto-detected.
        """
        # NOTE: `item.from_language` is intentionally not forwarded here; the
        # translator detects the source language itself.
        languages = _load_languages()
        is_supported = any(
            item.to_language == entry["code"] for entry in languages
        )
        to_lang = item.to_language if is_supported else "en"

        translated = Translator().translate(item.input_text, dest=to_lang)
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    def translate_ai_completions(self, item: TranslateCompletionsPostItem):
        """Translate text with the M2M100 model.

        Unsupported source or target languages fall back to English. A source
        language of `auto` is resolved by language detection.
        """
        translator = Translator()
        codes = [entry["code"] for entry in _load_languages()]

        to_lang = item.to_language if item.to_language in codes else "en"
        from_lang = item.from_language if item.from_language in codes else "en"
        if to_lang == "auto":
            to_lang = "en"
        if from_lang == "auto":
            # NOTE(review): the detected code is passed to M2M100 as-is;
            # confirm every code googletrans can return is known to the model.
            from_lang = translator.detect(item.input_text).lang

        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
            logging.warning(
                "GPU not found, using CPU, translation will be very slow."
            )

        time_start = time.time()
        with _M2M100_LOCK:
            # The model is loaded on first use only; later requests reuse it.
            tokenizer, model = _load_m2m100(
                M2M100_MODEL_NAME, M2M100_CACHE_DIR, str(device)
            )
            # `src_lang` is state on the shared tokenizer, so it is set and
            # consumed while holding the lock.
            tokenizer.src_lang = from_lang
            encoded_input = tokenizer(
                item.input_text, return_tensors="pt"
            ).to(device)
            target_lang_id = tokenizer.get_lang_id(to_lang)

        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=target_lang_id
            )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
        time_end = time.time()

        item_response = {
            "from_language": from_lang,
            "to_language": to_lang,
            "text": item.input_text,
            "translate": translated_text,
            "start": str(time_start),
            "end": str(time_end),
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class DetectLanguagePostItem(BaseModel):
        """Request body for the language detection endpoint."""

        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        """Detect the language of the given text."""
        detected = Translator().detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class TTSPostItem(BaseModel):
        """Request body for the text-to-speech endpoint."""

        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        """Convert text to speech and stream the audio back.

        On failure a JSON body of `{"status": 400}` is returned.
        """
        try:
            speech = gTTS(
                text=item.input_text, lang=item.from_language, slow=False
            )
            # NOTE(review): gTTS.stream() appears to be lazy, so failures
            # during synthesis would surface while the response is being
            # streamed rather than in the handler below -- confirm.
            return StreamingResponse(speech.stream(), media_type="audio/mpeg")
        except Exception:
            # Keep the original best-effort contract (same JSON body), but no
            # longer swallow KeyboardInterrupt/SystemExit, and record why the
            # request failed.
            logging.exception("Text-to-speech request failed")
            return JSONResponse(content=jsonable_encoder({"status": 400}))

    def setup_routes(self):
        """Register every endpoint both at the root and under `/v1`."""
        routes = [
            (self.app.get, "/langs", "Get available languages",
             self.get_available_langs),
            (self.app.post, "/translate", "translate text",
             self.translate_completions),
            (self.app.post, "/translate/ai", "translate text with ai",
             self.translate_ai_completions),
            (self.app.post, "/detect", "detect language",
             self.detect_language),
            (self.app.post, "/tts", "text to speech",
             self.text_to_speech),
        ]
        for prefix in ["", "/v1"]:
            for register, path, summary, handler in routes:
                register(prefix + path, summary=summary)(handler)


class ArgParser(argparse.ArgumentParser):
    """Command-line parser; the parsed result is exposed as ``self.args``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        # Parsing happens at construction time, from the process arguments.
        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

# NOTE(review): every origin is allowed together with credentials; confirm
# this is intended for the deployment before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    args = ArgParser().args
    # Auto-reload is enabled only in dev mode.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

    # python -m apis.chat_api      # [Docker] production mode
    # python -m apis.chat_api -d   # [Dev] development mode