Spaces:
Running
Running
import argparse | |
import uvicorn | |
import sys | |
import json | |
import string | |
import random | |
import base64 | |
from fastapi import FastAPI, Response | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.responses import JSONResponse, StreamingResponse | |
from pydantic import BaseModel, Field | |
from sse_starlette.sse import EventSourceResponse | |
from utils.logger import logger | |
from networks.message_streamer import MessageStreamer | |
from messagers.message_composer import MessageComposer | |
from googletrans import Translator | |
from io import BytesIO | |
from gtts import gTTS | |
from fastapi.middleware.cors import CORSMiddleware | |
class ChatAPIApp:
    """FastAPI application exposing translation, language detection and TTS.

    The constructor builds the ``FastAPI`` instance (Swagger UI served at
    ``/``) and registers every endpoint both at the root and under ``/v1``.
    """

    # Language table used by the /models and /translate handlers. The path is
    # relative to the working directory of the server process.
    LANGUAGES_FILE = "apis/lang_name.json"

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def _load_languages(self):
        """Parse and return the content of ``LANGUAGES_FILE``.

        The file is opened in binary mode so ``json.load`` detects the
        UTF-8/16/32 encoding itself, and the handle is always closed.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(self.LANGUAGES_FILE, "rb") as lang_file:
            return json.load(lang_file)

    def get_available_models(self):
        """Return the language table (handler for ``GET /models``).

        The parsed table is also cached on ``self.available_models``.
        """
        self.available_models = self._load_languages()
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        # Request body of the /translate endpoint.
        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Translate ``item.input_text`` (handler for ``POST /translate``).

        Args:
            item: request body. ``to_language`` and ``from_language`` are
                checked against the ``code`` values of the language table;
                an unknown target falls back to ``"en"`` and an unknown
                source falls back to ``"auto"`` (automatic detection).

        Returns:
            JSONResponse with keys ``from_language``, ``to_language``,
            ``text`` (the original input) and ``translate`` (the result).
        """
        known_codes = {entry["code"] for entry in self._load_languages()}
        to_lang = item.to_language if item.to_language in known_codes else "en"
        # Honour the requested source language; previously the field was
        # accepted but silently ignored and detection was always automatic.
        from_lang = (
            item.from_language if item.from_language in known_codes else "auto"
        )

        translated = Translator().translate(
            item.input_text, dest=to_lang, src=from_lang
        )
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class DetectLanguagePostItem(BaseModel):
        # Request body of the /detect endpoint.
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        """Detect the language of ``item.input_text`` (``POST /detect``).

        Returns:
            JSONResponse with keys ``lang`` and ``confidence`` taken from the
            translator's detection result.
        """
        detected = Translator().detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    class TTSPostItem(BaseModel):
        # Request body of the /tts endpoint.
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        """Synthesize ``item.input_text`` as MP3 audio (``POST /tts``).

        The audio is rendered into memory inside the ``try`` block so that
        synthesis failures are actually caught; a lazily streamed generator
        would only fail after this handler had already returned.

        Returns:
            Response carrying the MP3 bytes with media type ``audio/mpeg``,
            or a JSONResponse ``{"status": 400}`` when synthesis fails.
        """
        try:
            speech = gTTS(
                text=item.input_text, lang=item.from_language, slow=False
            )
            audio_buffer = BytesIO()
            speech.write_to_fp(audio_buffer)
        except Exception:
            # The failure is reported in the body only; the HTTP status code
            # is left unchanged so existing clients keep working.
            return JSONResponse(content=jsonable_encoder({"status": 400}))
        return Response(content=audio_buffer.getvalue(), media_type="audio/mpeg")

    def setup_routes(self):
        """Register all endpoints at the root and under the ``/v1`` prefix."""
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)
            self.app.post(
                prefix + "/translate",
                summary="translate text",
            )(self.chat_completions)
            self.app.post(
                prefix + "/detect",
                summary="detect language",
            )(self.detect_language)
            self.app.post(
                prefix + "/tts",
                summary="text to speech",
            )(self.text_to_speech)
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the API server.

    The options are parsed from ``sys.argv`` during construction and the
    resulting namespace is stored on ``self.args`` with the attributes
    ``server`` (str), ``port`` (int) and ``dev`` (bool).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (flags, add_argument keyword options) for every supported switch.
        option_specs = (
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 23333,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
            (
                ("-d", "--dev"),
                {
                    "default": False,
                    "action": "store_true",
                    "help": "Run in dev mode",
                },
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)

        # Skip the program name; parse only the user-supplied arguments.
        self.args = self.parse_args(sys.argv[1:])
# Module-level ASGI application; uvicorn locates it through the
# "__main__:app" import string used below.
app = ChatAPIApp().app

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive CORS policy — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    # Usage:
    #   python -m apis.chat_api      # [Docker] on product mode
    #   python -m apis.chat_api -d   # [Dev] on develop mode
    args = ArgParser().args
    # Auto-reload is enabled exactly when the -d/--dev flag was given.
    uvicorn.run(
        "__main__:app",
        host=args.server,
        port=args.port,
        reload=args.dev,
    )