# NOTE(review): Hugging Face Spaces page header captured with this file (not code):
# Spaces: Runtime error
from datetime import datetime
from typing import Optional

import torch
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

from config import TEST_MODE
from logger import log
# Router for this module's endpoints. No route is registered on it in the
# visible code; presumably it is mounted/wired up elsewhere — confirm.
router: APIRouter = APIRouter()
class SentenceEmbeddingsInput(BaseModel):
    """Request body for the sentence-embeddings handler."""

    # Texts to embed; the embedding functions return one vector per entry.
    inputs: list[str]
    # Key into ``sentence_embeddings_mapping`` (a Hugging Face checkpoint name).
    model: str
    # Extra options passed through to the embedding function. Optional because
    # the built-in embedding functions do not read it. Pydantic copies mutable
    # defaults per instance, so the ``{}`` literal is not shared between requests.
    parameters: dict = {}
class SentenceEmbeddingsOutput(BaseModel):
    """Response body for the sentence-embeddings handler.

    Exactly one field is meaningful per response: ``embeddings`` on success,
    ``error`` on failure. The unused field is left as ``None``.
    """

    # One vector (list of floats) per input text, in request order.
    embeddings: Optional[list[list[float]]] = None
    # Human-readable failure reason.
    error: Optional[str] = None
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    """Embed ``inputs.inputs`` with the model registered as ``inputs.model``.

    Failures are reported in-band rather than raised. An unknown model name,
    or any exception from the embedding function, the ``log`` call or the
    timestamp update, yields a ``SentenceEmbeddingsOutput`` whose ``error``
    holds the message.

    On success, the request, outputs and timing are passed to ``log``, and
    ``loaded_models_last_updated[inputs.model]`` is set to the current time.
    """
    began = datetime.now()

    embed = sentence_embeddings_mapping.get(inputs.model)
    if not embed:
        missing = f'No sentence embeddings model found for {inputs.model}'
        return SentenceEmbeddingsOutput(error=missing)

    try:
        vectors = embed(inputs.inputs, inputs.parameters)
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": began.isoformat(),
            "time_taken": (datetime.now() - began).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": vectors,
            "parameters": inputs.parameters,
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(embeddings=vectors)
    except Exception as exc:
        # API boundary: hand the failure back to the client instead of raising.
        return SentenceEmbeddingsOutput(error=str(exc))
def generic_sentence_embeddings(model_name: str):
    """Return an embedding function bound to the checkpoint ``model_name``.

    The returned ``process_texts(texts, parameters)`` behaves as follows:

    * For an empty ``texts`` it returns ``[]`` without loading anything.
    * When ``TEST_MODE`` is set, it returns a dummy ``[0.1, 0.2]`` vector per
      text.
    * Otherwise, on first use it loads the tokenizer and model and caches them
      in the module-level ``loaded_models`` dict under ``model_name``. It then
      runs one batched forward pass, takes the hidden state at sequence
      position 0, L2-normalises it, and returns one ``list[float]`` per text.

    ``parameters`` is accepted to match the dispatcher's call signature but is
    not used.
    """

    def process_texts(texts: list[str], parameters: dict):
        if not texts:
            # Nothing to embed: skip model loading and avoid tokenizing an
            # empty batch (which may not be supported by the tokenizer).
            return []
        if TEST_MODE:
            # Separate inner list per row so rows do not alias one object.
            return [[0.1, 0.2] for _ in texts]

        cached = loaded_models.get(model_name)
        if cached is None:
            cached = (
                AutoTokenizer.from_pretrained(model_name),
                AutoModel.from_pretrained(model_name),
            )
            # Key by checkpoint name so later calls reuse these weights instead
            # of reloading them (and growing the cache) on every request.
            loaded_models[model_name] = cached
        tokenizer, model = cached

        encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            output = model(**encoded)
        # output[0] is presumably last_hidden_state, shaped (batch, seq, hidden).
        # Position 0 is the [CLS] token for BERT-style tokenizers such as BGE's.
        embeddings = output[0][:, 0]
        # Unit-length rows, so a dot product equals cosine similarity.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.tolist()

    return process_texts
# NOTE(review): the original comment here was cut off ("Polling every X
# minutes to"). Presumably a poller elsewhere reads loaded_models_last_updated
# to evict idle models from loaded_models. No such poller is visible in this
# file — confirm.

# Cache of loaded (tokenizer, model) pairs, filled lazily by the closures
# that generic_sentence_embeddings returns.
loaded_models: dict = {}
# Model name -> datetime of the last successful sentence_embeddings call.
loaded_models_last_updated: dict = {}

# Model name -> embedding function. These are the only models the handler serves.
sentence_embeddings_mapping = {
    name: generic_sentence_embeddings(name)
    for name in ('BAAI/bge-base-en-v1.5', 'BAAI/bge-large-en-v1.5')
}