fastapi_ai_endpoints / sentence_embeddings.py
jxtan's picture
Update sentence_embeddings.py
355fbaf verified
raw
history blame
2.8 kB
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
from logger import log
from config import TEST_MODE
# Run inference on the first CUDA device when PyTorch reports one; otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Router on which this module's endpoint is registered.
router = APIRouter()
# Request body accepted by the '/sentence-embeddings' endpoint.
# (Documented with comments rather than a docstring so the generated schema is unchanged.)
class SentenceEmbeddingsInput(BaseModel):
    # Texts to embed; handed to the selected embedding function as its first argument.
    inputs: list[str]
    # Key looked up in `sentence_embeddings_mapping` to choose the embedding function.
    model: str
    # Passed through to the embedding function and included in the log record.
    # Required: the field has no default.
    parameters: dict
# Response body returned by the '/sentence-embeddings' endpoint.
# Both fields default to None; the endpoint sets `embeddings` on success and `error` otherwise.
class SentenceEmbeddingsOutput(BaseModel):
    # Embedding vectors as lists of floats (set on success).
    embeddings: Optional[list[list[float]]] = None
    # Human-readable failure reason: unknown model name, or the text of a caught exception.
    error: Optional[str] = None
# POST /sentence-embeddings
#
# Looks up the embedding function registered for `inputs.model`, runs it on
# `inputs.inputs`, logs the call, and records when the model was last used.
# Failures are reported through the `error` field of the response instead of
# being raised.
# (Comments are used instead of a docstring so the endpoint's generated API
# description stays as it was.)
@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    started_at = datetime.now()

    embed = sentence_embeddings_mapping.get(inputs.model)
    if not embed:
        # Unknown model name: answer with an error payload.
        message = f'No sentence embeddings model found for {inputs.model}'
        return SentenceEmbeddingsOutput(error=message)

    try:
        vectors = embed(inputs.inputs, inputs.parameters)

        started_iso = started_at.isoformat()
        elapsed_seconds = (datetime.now() - started_at).total_seconds()
        record = {
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": started_iso,
            "time_taken": elapsed_seconds,
            "inputs": inputs.inputs,
            "outputs": vectors,
            "parameters": inputs.parameters,
        }
        log(record)

        # Remember the most recent successful use of this model.
        loaded_models_last_updated[inputs.model] = datetime.now()

        # Built inside the try block so that a failure while constructing the
        # response is also turned into an error payload.
        response = SentenceEmbeddingsOutput(embeddings=vectors)
        return response
    except Exception as exc:
        return SentenceEmbeddingsOutput(error=str(exc))
def generic_sentence_embeddings(model_name: str):
    """Build an embedding function bound to one pretrained model.

    The returned callable loads the tokenizer and model for ``model_name`` on
    first use, stores the pair in the module-level ``loaded_models`` dict
    under ``model_name``, and reuses the cached pair on later calls.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``; also the cache key.

    Returns:
        A function ``process_texts(texts, parameters)`` that returns one
        L2-normalised embedding (a list of floats) per input text.
    """

    def process_texts(texts: list[str], parameters: dict):
        """Embed ``texts`` with the model this closure was created for.

        Args:
            texts: Sentences to embed.
            parameters: Accepted for interface compatibility; not used here.

        Returns:
            A list with one embedding per entry of ``texts``. An empty
            ``texts`` yields an empty list without loading the model.
        """
        # Nothing to embed: skip model loading and tokenisation entirely.
        if not texts:
            return []

        if TEST_MODE:
            # Fixed dummy vectors; a fresh inner list per text so the rows
            # are not aliases of one another.
            return [[0.1, 0.2] for _ in texts]

        cached = loaded_models.get(model_name)
        if cached is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            # Key the cache by model *name*. Keying by the model object (as
            # before) meant the lookup above never hit, so the model was
            # reloaded on every call and each copy was kept alive in the dict.
            loaded_models[model_name] = (tokenizer, model)
        else:
            tokenizer, model = cached

        # Tokenize sentences as one padded, truncated batch on the target device.
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            model_output = model(**encoded_input)

        # First model output, position 0 of every sequence
        # (presumably [CLS]-token pooling of the last hidden state -- TODO confirm).
        pooled = model_output[0][:, 0]

        # Normalize embeddings to unit L2 norm along the feature dimension.
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return pooled.tolist()

    return process_texts
# NOTE(review): the comment that stood here ("Polling every X minutes to") was
# left unfinished. No polling code is visible in this file; presumably a
# periodic job was meant to use `loaded_models_last_updated` -- confirm intent.

# Cache of loaded (tokenizer, model) pairs, filled lazily by the embedding functions.
loaded_models = {}

# Model name -> time of its most recent successful request (written by the endpoint).
loaded_models_last_updated = {}

# Model name -> embedding function, for every model the endpoint can serve.
sentence_embeddings_mapping = {
    name: generic_sentence_embeddings(name)
    for name in (
        'BAAI/bge-base-en-v1.5',
        'BAAI/bge-large-en-v1.5',
    )
}