blobba's picture
Duplicate from blobba/sentence-transformers-test-4
57720fe
raw
history blame
1.4 kB
from typing import Generic, List, Optional, TypeVar
from functools import partial
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import numpy
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
# Single embedding model, created once at import time and shared by every
# request. Constructing it may download the "all-mpnet-base-v2" weights if
# they are not already cached locally (NOTE(review): confirm the deployment
# pre-fetches them so startup does not depend on network access).
MODEL = SentenceTransformer("all-mpnet-base-v2")
def cache(func):
    """Memoise a batch function that maps a list of strings to a list of results.

    The wrapped function is only called with sentences that have not been seen
    before. Each distinct sentence is passed at most once per call, even if it
    is repeated in the input. Results are stored per sentence. The returned
    list always has one entry per input sentence, in input order, including
    repeated sentences.

    ``func`` must return its results in the same order as the list it is given.

    NOTE(review): the cache is an unbounded dict keyed by sentence text, so it
    grows with every distinct sentence ever requested. Consider bounding it if
    untrusted clients can reach the wrapped function.
    """
    from functools import wraps

    inner_cache = {}

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # sentence repeated within one request is only computed once.
        missing = [s for s in dict.fromkeys(sentences) if s not in inner_cache]
        if missing:
            results = func(missing)
            for sentence, result in zip(missing, results):
                inner_cache[sentence] = result
        return [inner_cache[s] for s in sentences]

    return inner
@cache
def _encode(sentences: List[str]):
    """Embed *sentences* with MODEL, rounding every component to 3 decimals.

    The model call requests normalised embeddings and shows a progress bar.
    The result is a plain nested Python list of floats, one row per sentence
    in input order, which the JSON response layer can serialise directly.
    """
    vectors = MODEL.encode(
        sentences,
        batch_size=2,  # NOTE(review): unusually small batch size; confirm intent
        normalize_embeddings=True,
        show_progress_bar=True,
    )
    # Round the whole batch in one call; .tolist() then yields one list of
    # Python floats per row, the same values per-row rounding would give.
    return numpy.around(vectors, 3).tolist()
class EmbedReq(BaseModel):
    # Request body for POST /embed.
    # (Comments rather than a docstring: pydantic would copy a class docstring
    # into the generated JSON/OpenAPI schema as its description.)
    # NOTE(review): there is no limit on how many sentences, or how long, a
    # client may send. Consider bounding this, since each new sentence is
    # encoded and cached.
    sentences: List[str]  # texts to embed; the response keeps this order
# ASGI application object. The /embed route and the CORS middleware that
# follow are registered on it at import time.
app = FastAPI()
@app.post("/embed", response_class=ORJSONResponse)
def embed(embed: EmbedReq):
    # POST /embed: returns one embedding (a list of floats) per sentence in
    # the request, in the same order.
    # A plain `def` endpoint: per the FastAPI docs, such endpoints run in a
    # worker thread pool, so the blocking model call does not stall the event
    # loop. Comments rather than a docstring, because FastAPI would publish a
    # docstring as the route's OpenAPI description.
    requested = embed.sentences
    return _encode(requested)
# Accept cross-origin requests from any origin, with any method and header,
# presumably so browser front-ends hosted elsewhere can call /embed directly.
# NOTE(review): wildcard CORS lets every website call this endpoint from a
# visitor's browser. Narrow allow_origins if that is not intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)