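"""FastAPI router exposing a /sentence-embeddings endpoint backed by Hugging
Face transformer models (currently the BAAI/bge family), with lazy per-model
loading and an in-memory cache."""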
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
from logger import log
from config import TEST_MODE

device = "cuda:0" if torch.cuda.is_available() else "cpu"

router = APIRouter()


class SentenceEmbeddingsInput(BaseModel):
    inputs: list[str]
    model: str
    parameters: dict


class SentenceEmbeddingsOutput(BaseModel):
    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None


@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    start_time = datetime.now()
    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )
    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        # Record last use so idle models can be evicted by a periodic poller.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(embeddings=embeddings)
    except Exception as e:
        return SentenceEmbeddingsOutput(error=str(e))
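# Example exchange (payload values are illustrative, not taken from this repo):
#   POST /sentence-embeddings
#   {"inputs": ["hello world"], "model": "BAAI/bge-base-en-v1.5", "parameters": {}}
# responds with {"embeddings": [[...]], "error": null}; an unknown model name
# instead yields {"embeddings": null, "error": "No sentence embeddings model found for ..."}.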


def generic_sentence_embeddings(model_name: str):
    def process_texts(texts: list[str], parameters: dict):
        # In test mode, return fixed dummy vectors without loading any model.
        if TEST_MODE:
            return [[0.1, 0.2]] * len(texts)
        # Load the tokenizer and model on first use; cache them by model name.
        if model_name in loaded_models:
            tokenizer, model = loaded_models[model_name]
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            loaded_models[model_name] = (tokenizer, model)
        # Tokenize sentences
        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Use the [CLS] token's last hidden state as the sentence embedding,
        # the pooling the BGE models expect.
        sentence_embeddings = model_output[0][:, 0]
        # normalize embeddings
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.tolist()

    return process_texts


# In-memory model cache plus last-use timestamps; a poller that runs every X
# minutes can use these to unload models that have been idle.
loaded_models = {}
loaded_models_last_updated = {}
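# A possible shape for that eviction poller (a sketch only: the interval, idle
# threshold, and threading.Timer scheduling are assumptions, since the polling
# code itself is not in this file):
#
#   from datetime import timedelta
#   import threading
#
#   def evict_stale_models(max_idle: timedelta = timedelta(minutes=30)):
#       cutoff = datetime.now() - max_idle
#       for name, last_used in list(loaded_models_last_updated.items()):
#           if last_used < cutoff:
#               loaded_models.pop(name, None)
#               loaded_models_last_updated.pop(name, None)
#       threading.Timer(300, evict_stale_models).start()  # poll again in 5 min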

sentence_embeddings_mapping = {
    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
}
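

# Minimal usage sketch (the app wiring below is illustrative; in the real
# deployment this router is presumably included into the main FastAPI app).
# With TEST_MODE enabled in config, the request returns dummy vectors without
# downloading any model weights.
if __name__ == '__main__':
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(router)
    client = TestClient(app)
    response = client.post('/sentence-embeddings', json={
        'inputs': ['hello world'],
        'model': 'BAAI/bge-base-en-v1.5',
        'parameters': {},
    })
    print(response.json())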