"""HTTP API exposing translation (googletrans), language detection and
text-to-speech (gTTS) endpoints.

Run from the directory that contains `apis/`:
    python -m apis.chat_api     # [Docker] on product mode
    python -m apis.chat_api -d  # [Dev] on develop mode
"""
import argparse
import base64
import json
import logging
import random
import string
import sys
from io import BytesIO

import uvicorn
from fastapi import FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from googletrans import Translator
from gtts import gTTS
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer

# NOTE(review): base64, random, string, StreamingResponse, EventSourceResponse,
# logger, MessageStreamer and MessageComposer are not used in this module;
# they are kept in case importing them has side effects elsewhere.

# Language catalogue. The path is relative to the current working directory,
# which is why the server must be started from the directory containing
# `apis/` (see the run commands in the module docstring).
LANG_NAME_PATH = "apis/lang_name.json"


def _load_languages():
    """Parse LANG_NAME_PATH; entries are expected to carry a "code" key."""
    with open(LANG_NAME_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


class ChatAPIApp:
    """Build the FastAPI app and register the translate/detect/TTS routes."""

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the language list stored in apis/lang_name.json."""
        self.available_models = _load_languages()
        return self.available_models

    # Request body for the /translate endpoint.
    class ChatCompletionsPostItem(BaseModel):
        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Translate `input_text` from `from_language` into `to_language`.

        Codes not listed in apis/lang_name.json fall back to "en" for the
        target language and to "auto" (auto-detect) for the source language.
        """
        known_codes = {lang.get("code") for lang in _load_languages()}
        to_lang = item.to_language if item.to_language in known_codes else "en"
        from_lang = item.from_language if item.from_language in known_codes else "auto"

        translator = Translator()
        translated = translator.translate(item.input_text, src=from_lang, dest=to_lang)
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    # Request body for the /detect endpoint.
    class DetectLanguagePostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        """Detect the language of `input_text` and report the confidence."""
        translator = Translator()
        detected = translator.detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    # Request body for the /tts endpoint.
    class TTSPostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        """Synthesize `input_text` in `from_language` and return it as MP3.

        If gTTS raises while synthesizing (for example an unsupported
        language or a failed upstream request), the response is HTTP 400
        with the body {"status": 400}.
        """
        try:
            speech = gTTS(text=item.input_text, lang=item.from_language, slow=False)
            # Render all audio before responding. A lazy StreamingResponse would
            # raise upstream errors after the 200 headers had been sent, where
            # this handler could no longer turn them into an error response.
            mp3_fp = BytesIO()
            speech.write_to_fp(mp3_fp)
        except Exception:
            logging.getLogger(__name__).exception("Text-to-speech failed")
            item_response = {"status": 400}
            return JSONResponse(
                content=jsonable_encoder(item_response), status_code=400
            )
        return Response(content=mp3_fp.getvalue(), media_type="audio/mpeg")

    def setup_routes(self):
        """Register every endpoint both at the root and under the /v1 prefix."""
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)

            self.app.post(
                prefix + "/translate",
                summary="translate text",
            )(self.chat_completions)

            self.app.post(
                prefix + "/detect",
                summary="detect language",
            )(self.detect_language)

            self.app.post(
                prefix + "/tts",
                summary="text to speech",
            )(self.text_to_speech)


class ArgParser(argparse.ArgumentParser):
    """Command-line options for the server; parses sys.argv on construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for a browser-facing API; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    args = ArgParser().args
    # The app is passed as an import string ("__main__:app") because uvicorn
    # requires one when reload is enabled; --dev turns reload on.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)