embed-api / apps /app.py
Hansimov's picture
:zap: [Enhance] Use JinaAIOnnxEmbedder for better performance
20ba605
raw
history blame contribute delete
No virus
3.23 kB
import argparse
import markdown2
import sys
import uvicorn
from pathlib import Path
from typing import Union, Optional
from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.responses import HTMLResponse
from tclogger import logger, OSEnver
from transforms.embed import JinaAIEmbedder, JinaAIOnnxEmbedder
from configs.constants import AVAILABLE_MODELS
# Service settings (app name, version, server, port) are read from
# <repo root>/configs/info.json, located relative to this file.
info_path = Path(__file__).parents[1].joinpath("configs", "info.json")
ENVER = OSEnver(info_path)
class EmbeddingApp:
    """FastAPI application exposing text-embedding endpoints.

    Routes registered by ``setup_routes``:

    * ``GET  /models``    -- list of available model names.
    * ``POST /embedding`` -- embedding(s) for the posted text(s).
    * ``GET  /readme``    -- project README rendered as HTML (hidden from the
      OpenAPI schema).

    The interactive API docs are served at ``/``.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title=ENVER["app_name"],
            # Depth -1 hides the "Schemas" section in Swagger UI.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=ENVER["version"],
        )
        # A single embedder instance is created up front and reused by
        # every request.
        self.embedder = JinaAIOnnxEmbedder()
        self.setup_routes()

    def get_available_models(self):
        """Return the list of model names from ``AVAILABLE_MODELS``."""
        return AVAILABLE_MODELS

    def get_readme(self):
        """Render the repository ``README.md`` to an HTML string."""
        readme_path = Path(__file__).parents[1] / "README.md"
        readme_str = readme_path.read_text(encoding="utf-8")
        # The markdown2 extra is named "tables" (plural); the previous
        # "table" was silently ignored, so tables were never rendered.
        readme_html = markdown2.markdown(
            readme_str,
            extras=["tables", "fenced-code-blocks", "highlightjs-lang"],
        )
        return readme_html

    class CalcEmbeddingPostItem(BaseModel):
        """Request body of ``POST /embedding``."""

        # Required: the annotation does not allow None, so a None default
        # would let an invalid request through to the embedder.
        text: Union[str, list[str]] = Field(
            ...,
            description="Input text(s) to embed",
        )
        # `description` (not `summary`) is the Field argument that ends up
        # in the generated OpenAPI schema.
        model: Optional[str] = Field(
            default=AVAILABLE_MODELS[0],
            description="Embedding model name",
        )

    def calc_embedding(self, item: CalcEmbeddingPostItem):
        """Encode the posted text(s) and return the embedding(s).

        Args:
            item: Request body carrying ``text`` (a string or a list of
                strings) and ``model``.

        Returns:
            A single embedding (flat list) when exactly one embedding was
            produced, otherwise a list of embeddings. An empty input list
            yields an empty list.
        """
        texts = item.text
        logger.note(f"> Encoding text: [{texts}]", end=" ")

        # Nothing to encode: skip the embedder and avoid indexing into an
        # empty result below.
        if isinstance(texts, list) and not texts:
            logger.success("[0]")
            return []

        # NOTE: `item.model` is currently ignored -- model switching is
        # disabled and the embedder created in __init__ is always used.
        embeddings = self.embedder.encode(texts).tolist()

        dim = len(embeddings[0]) if embeddings else 0
        logger.success(f"[{dim}]")

        # Kept for backward compatibility: a lone result is unwrapped into
        # a flat vector, including for a one-element input list.
        if len(embeddings) == 1:
            return embeddings[0]
        return embeddings

    def setup_routes(self):
        """Register the HTTP routes on ``self.app``."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/embedding",
            summary="Calculate embedding for input text(s)",
        )(self.calc_embedding)

        self.app.get(
            "/readme",
            summary="README of Embed API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the Embed API server.

    Defines ``-s/--server`` and ``-p/--port`` (defaults taken from ``ENVER``)
    and parses ``sys.argv[1:]`` immediately on construction; the parsed
    namespace is available as ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        default_server = ENVER["server"]
        default_port = ENVER["port"]

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default=default_server,
            help=f"Server IP ({default_server}) for Embed API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=default_port,
            help=f"Server Port ({default_port}) for Embed API",
        )

        # Parse right away so callers can simply read `ArgParser().args`.
        self.args = self.parse_args(sys.argv[1:])
# ASGI application object, built at import time so it can be referenced by
# the "__main__:app" import string below.
app = EmbeddingApp().app

if __name__ == "__main__":
    # Host and port come from the command line, falling back to ENVER.
    cli_args = ArgParser().args
    uvicorn.run("__main__:app", host=cli_args.server, port=cli_args.port)

    # python -m apps.app