from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
from logger import log
from config import TEST_MODE

device = "cuda:0" if torch.cuda.is_available() else "cpu"
router = APIRouter()

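# Request/response schemas: a batch of input texts, the embedding model id, and a
# free-form parameters dict (accepted but not used by the generic handler below).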
class SentenceEmbeddingsInput(BaseModel):
    inputs: list[str]
    model: str
    parameters: dict

class SentenceEmbeddingsOutput(BaseModel):
    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None

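# POST /sentence-embeddings: look up the handler for the requested model, compute
# embeddings, log the request and timing, record the model's last-used time, and
# return the embeddings (or an error if the model is unknown or the handler fails).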
@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    start_time = datetime.now()
    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )
    
    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(
            embeddings=embeddings
        )
    except Exception as e:
        return SentenceEmbeddingsOutput(
            error=str(e)
        )

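# Factory returning a handler for a given Hugging Face model name. The handler
# lazily loads and caches the tokenizer/model, embeds a batch of texts via the
# [CLS] token representation, and L2-normalizes the result.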
def generic_sentence_embeddings(model_name: str):

    def process_texts(texts: list[str], parameters: dict):
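        # In TEST_MODE, skip model loading entirely and return fixed dummy embeddings.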
        if TEST_MODE:
            return [[0.1,0.2]] * len(texts)
    
        # Lazily load the tokenizer and model on first use and cache them by name.
        if model_name in loaded_models:
            tokenizer, model = loaded_models[model_name]
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            loaded_models[model_name] = (tokenizer, model)

        # Tokenize sentences
        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
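            # Use the [CLS] (first) token's hidden state as the sentence embedding.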
            sentence_embeddings = model_output[0][:, 0]

        # normalize embeddings
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.tolist()

    return process_texts

# Cache of loaded models and the time each was last used, so a poller running
# every X minutes can unload models that have been idle.
loaded_models = {}
loaded_models_last_updated = {}

sentence_embeddings_mapping = {
    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
}
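
# Illustrative usage (a sketch, not part of this module: the app setup lives
# elsewhere, and the module name "sentence_embeddings" below is an assumption):
#
#     from fastapi import FastAPI
#     from sentence_embeddings import router as embeddings_router
#
#     app = FastAPI()
#     app.include_router(embeddings_router)
#
# Example request body for POST /sentence-embeddings:
#
#     {"inputs": ["hello world"], "model": "BAAI/bge-base-en-v1.5", "parameters": {}}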