"""FastAPI route that serves sentence embeddings from Hugging Face models."""

from datetime import datetime
from typing import Callable, Optional

import torch
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

from hf_to_api.config import TEST_MODE
from logger import log

router = APIRouter()

# Cache of (tokenizer, model) pairs keyed by model name, filled lazily on the
# first real (non-test) request for each model.
loaded_models: dict = {}

# Timestamp of the most recent successful request per model name.
# NOTE(review): the original comment here read "Polling every X minutes to"
# and was cut off. Presumably these timestamps are meant for a periodic job
# that unloads idle models; no such job is visible in this file -- confirm.
loaded_models_last_updated: dict = {}


class SentenceEmbeddingsInput(BaseModel):
    """Request body for the sentence-embeddings endpoint."""

    # Sentences to embed, one embedding per entry.
    inputs: list[str]
    # Model name; must be a key of `sentence_embeddings_mapping`.
    model: str
    # Passed through to the embedding function and recorded in the log.
    parameters: dict


class SentenceEmbeddingsOutput(BaseModel):
    """Response body: `embeddings` on success, `error` on failure."""

    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None


# No return annotation on purpose: FastAPI would treat it as a response model.
@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    """Embed a batch of sentences with the requested model.

    The response carries `embeddings` on success. When the model name is not
    registered, or anything raises while embedding or logging, the response
    carries `error` instead of the request failing.
    """
    start_time = datetime.now()

    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )

    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(embeddings=embeddings)
    except Exception as e:
        # Deliberately broad: this is the HTTP boundary, and failures are
        # reported in the response body rather than raised.
        return SentenceEmbeddingsOutput(error=str(e))


def generic_sentence_embeddings(
    model_name: str,
) -> Callable[[list[str], dict], list[list[float]]]:
    """Build an embedding function bound to one Hugging Face model.

    Args:
        model_name: Identifier passed to `AutoTokenizer.from_pretrained` and
            `AutoModel.from_pretrained`; also the key used in `loaded_models`.

    Returns:
        A function `process_texts(texts, parameters)` that returns one
        L2-normalized embedding (a list of floats) per input text.
    """

    def process_texts(texts: list[str], parameters: dict) -> list[list[float]]:
        """Embed `texts`; `parameters` is accepted but currently unused."""
        if TEST_MODE:
            # Fixed placeholder vector per input; no model is loaded.
            return [[0.1, 0.2] for _ in texts]

        if not texts:
            # Nothing to embed; avoids handing an empty batch to the tokenizer.
            return []

        # Load once and cache under the model *name*. The previous code keyed
        # the cache by the model object, so the lookup never hit and the model
        # was reloaded (and another entry added) on every request.
        if model_name not in loaded_models:
            loaded_models[model_name] = (
                AutoTokenizer.from_pretrained(model_name),
                AutoModel.from_pretrained(model_name),
            )
        tokenizer, model = loaded_models[model_name]

        # Tokenize sentences
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        )

        with torch.no_grad():
            model_output = model(**encoded_input)

        # Take position 0 of every sequence from the first model output
        # (presumably the [CLS] token of the last hidden state -- confirm
        # against the model card).
        sentence_embeddings = model_output[0][:, 0]

        # normalize embeddings so each row has unit L2 norm
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
        return sentence_embeddings.tolist()

    return process_texts


# Registry of supported model names -> embedding functions.
sentence_embeddings_mapping = {
    name: generic_sentence_embeddings(name)
    for name in (
        'BAAI/bge-base-en-v1.5',
        'BAAI/bge-large-en-v1.5',
    )
}