File size: 1,395 Bytes
57720fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from functools import partial, wraps
from typing import Generic, List, Optional, TypeVar

import numpy
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Sentence-embedding model shared by every request; constructed once, when
# this module is imported.
# NOTE(review): "all-mpnet-base-v2" is presumably downloaded from the model
# hub on first use and cached locally -- confirm for offline deployments.
MODEL = SentenceTransformer("all-mpnet-base-v2")

def cache(func):
    """Memoize a batch function on a per-item basis.

    ``func`` must accept a list of strings and return one result per string,
    in the same order. The returned wrapper forwards only strings it has not
    seen before -- each distinct string at most once per call -- and answers
    everything else from an in-process dict.

    Args:
        func: Batch function ``List[str] -> Sequence[result]``.

    Returns:
        A wrapper that takes ``List[str]`` and returns a list holding one
        result per input string, in input order. Empty input returns ``[]``
        without calling ``func``.

    Raises:
        KeyError: from the wrapper, if ``func`` returns fewer results than
            strings it was given (the surplus strings never get cached).

    NOTE(review): entries are never evicted, so memory grows with the number
    of distinct strings ever seen.
    """
    inner_cache = dict()

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []
        # dict.fromkeys drops duplicates while keeping first-seen order, so a
        # string repeated within one request is sent to ``func`` only once.
        not_in_cache = [s for s in dict.fromkeys(sentences) if s not in inner_cache]
        if not_in_cache:
            # Hand ``func`` its own copy so it cannot disturb the zip pairing
            # by mutating the list it receives.
            inner_cache.update(zip(not_in_cache, func(list(not_in_cache))))
        return [inner_cache[s] for s in sentences]

    return inner

@cache
def _encode(sentences: List[str]):
    """Embed ``sentences`` with ``MODEL`` and return plain Python lists.

    The model is asked for normalized embeddings (``normalize_embeddings=True``),
    and every component is rounded to 3 decimal places -- presumably to keep
    the JSON response small; confirm before changing the precision.

    Returns:
        One list of floats per input sentence, in input order.
    """
    # NOTE(review): batch_size=2 is far below the library default; presumably
    # chosen to bound memory use on the host -- confirm before raising it.
    vectors = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        batch_size=2,
        show_progress_bar=True,
    )
    # Round the whole matrix in one call; .tolist() then yields one Python
    # list per row.
    return numpy.around(vectors, 3).tolist()

class EmbedReq(BaseModel):
    """Request body for ``POST /embed``: the sentences to embed."""

    sentences: List[str]  # texts to embed; the response keeps this order

# ASGI application object; the /embed route and the CORS middleware in this
# module are registered on it.
app = FastAPI()

@app.post("/embed", response_class=ORJSONResponse)
def embed(embed: EmbedReq):
    """Embed every sentence in the request body.

    Responds with one list of floats per sentence, in request order,
    serialized through ``ORJSONResponse``.
    """
    sentences = embed.sentences
    vectors = _encode(sentences)
    return vectors

# Accept cross-origin requests from any origin, with any method and header --
# presumably so browser front-ends hosted elsewhere can call /embed directly.
# NOTE(review): wildcard origins suit a public, credential-free API; narrow
# allow_origins if the service should only serve known callers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)