radinhas's picture
Update apis/chat_api.py
b240a3f
raw
history blame
6.02 kB
import argparse
import uvicorn
import sys
import json
import string
import random
import base64
from fastapi import FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from googletrans import Translator
from io import BytesIO
from gtts import gTTS
from fastapi.middleware.cors import CORSMiddleware
class ChatAPIApp:
    """FastAPI application exposing translation, language detection and TTS.

    Every endpoint is registered twice: once at the root path and once under
    the ``/v1`` prefix (see ``setup_routes``).
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def _load_languages(self):
        # Read the supported-language list fresh on each call. The context
        # manager closes the file (the previous bare open() leaked the
        # handle), and JSON is decoded explicitly as UTF-8 instead of with
        # the platform's locale encoding.
        with open("apis/lang_name.json", "r", encoding="utf-8") as f:
            return json.load(f)

    def get_available_models(self):
        """Return the supported languages listed in `apis/lang_name.json`."""
        # The parsed list is also stored on the instance, as before.
        self.available_models = self._load_languages()
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        # Request body for the translate endpoint.
        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Translate `input_text` from `from_language` into `to_language`.

        If `to_language` is not a supported language code, the text is
        translated into `en`. If `from_language` is not a supported code,
        the source language is detected automatically.
        """
        translator = Translator()
        # Entries are expected to be mappings carrying a "code" key.
        known_codes = {lang.get("code") for lang in self._load_languages()}
        to_lang = item.to_language if item.to_language in known_codes else "en"
        # from_language used to be accepted but silently ignored; honour it
        # when it is a known code, otherwise let googletrans auto-detect.
        from_lang = (
            item.from_language if item.from_language in known_codes else "auto"
        )
        translated = translator.translate(
            item.input_text, src=from_lang, dest=to_lang
        )
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class DetectLanguagePostItem(BaseModel):
        # Request body for the detect endpoint.
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        """Detect the language of `input_text` and report the confidence."""
        translator = Translator()
        detected = translator.detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class TTSPostItem(BaseModel):
        # Request body for the text-to-speech endpoint.
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        """Speak `input_text` in `from_language` and stream it back as MP3.

        If the request is rejected (for example an unsupported language),
        the response is the JSON body `{"status": 400}` with HTTP status 400.
        """
        try:
            audioobj = gTTS(text=item.input_text, lang=item.from_language, slow=False)
            audio_stream = audioobj.stream()
        except Exception:
            # NOTE(review): if gTTS.stream() is a lazy generator, failures
            # while fetching audio surface during streaming, not here.
            # The HTTP status now matches the "status" field in the body.
            return JSONResponse(
                content=jsonable_encoder({"status": 400}),
                status_code=400,
            )
        # gTTS produces MP3 data; declare it so clients can play it.
        return StreamingResponse(audio_stream, media_type="audio/mpeg")

    def setup_routes(self):
        # Register every endpoint both at the root and under "/v1".
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)
            self.app.post(
                prefix + "/translate",
                summary="translate text",
            )(self.chat_completions)
            self.app.post(
                prefix + "/detect",
                summary="detect language",
            )(self.detect_language)
            self.app.post(
                prefix + "/tts",
                summary="text to speech",
            )(self.text_to_speech)
class ArgParser(argparse.ArgumentParser):
    """Command-line options for launching the API server.

    Construction immediately parses ``sys.argv[1:]``; the resulting
    namespace is available as ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (flags, keyword options) for each supported command-line switch.
        option_specs = (
            (
                ("-s", "--server"),
                dict(type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API"),
            ),
            (
                ("-p", "--port"),
                dict(type=int, default=23333, help="Server Port for HF LLM Chat API"),
            ),
            (
                ("-d", "--dev"),
                dict(default=False, action="store_true", help="Run in dev mode"),
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)
        self.args = self.parse_args(sys.argv[1:])
# Module-level ASGI application; the entry point below serves it as "__main__:app".
app = ChatAPIApp().app
# Fully open CORS policy: any origin, method and header is accepted.
# NOTE(review): the CORS spec forbids a wildcard origin together with
# credentials, so the middleware may reflect the caller's origin instead,
# which effectively trusts every site for credentialed requests. Confirm
# that credentials are actually needed before keeping allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
if __name__ == "__main__":
    args = ArgParser().args
    # --dev is a store_true flag, so args.dev is exactly True or False and can
    # drive uvicorn's auto-reload directly (on in dev mode, off otherwise).
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)
# python -m apis.chat_api # [Docker] on product mode
# python -m apis.chat_api -d # [Dev] on develop mode