File size: 2,732 Bytes
b805057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
from config import TEST_MODE, device, log

router = APIRouter()

class SentenceEmbeddingsInput(BaseModel):
    """Request body for POST /sentence-embeddings."""
    # Texts to embed.
    inputs: list[str]
    # Name of the model; used as the key into sentence_embeddings_mapping,
    # e.g. 'BAAI/bge-base-en-v1.5'.
    model: str
    # Forwarded verbatim to the selected embedding function. The generic
    # handler in this module does not read it. The field has no default, so
    # clients must send at least {}.
    parameters: dict

class SentenceEmbeddingsOutput(BaseModel):
    """Response body for POST /sentence-embeddings.

    The endpoint in this module populates ``embeddings`` on success and
    ``error`` on failure; the other field is left as ``None``.
    """
    # One embedding vector per input text (set on success).
    embeddings: Optional[list[list[float]]] = None
    # Failure message: unknown model name, or the text of the raised exception.
    error: Optional[str] = None

@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    """Embed ``inputs.inputs`` with the model named by ``inputs.model``.

    Failures are reported in-band rather than as HTTP errors. An unknown
    model name, or any exception raised while embedding, logging or building
    the response, produces a ``SentenceEmbeddingsOutput`` whose ``error``
    holds the message.

    On success the request and its result are passed to ``log``, and the
    model's entry in ``loaded_models_last_updated`` is refreshed to now.
    """
    started_at = datetime.now()
    requested_model = inputs.model

    embed = sentence_embeddings_mapping.get(requested_model)
    if not embed:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {requested_model}'
        )

    try:
        vectors = embed(inputs.inputs, inputs.parameters)
        elapsed_seconds = (datetime.now() - started_at).total_seconds()
        record = {
            "task": "sentence_embeddings",
            "model": requested_model,
            "start_time": started_at.isoformat(),
            "time_taken": elapsed_seconds,
            "inputs": inputs.inputs,
            "outputs": vectors,
            "parameters": inputs.parameters,
        }
        log(record)
        loaded_models_last_updated[requested_model] = datetime.now()
        # Built inside the try so a validation failure is also reported in-band.
        return SentenceEmbeddingsOutput(embeddings=vectors)
    except Exception as exc:
        # API boundary: surface any failure to the client as a string.
        return SentenceEmbeddingsOutput(error=str(exc))

def generic_sentence_embeddings(model_name: str):
    """Build an embedding function for the Hugging Face model ``model_name``.

    The returned ``process_texts(texts, parameters)`` works as follows:

    * On first real use it loads the tokenizer and model, moves the model to
      ``device``, and caches the pair in the module-level ``loaded_models``
      dict under ``model_name``. Later calls reuse the cached pair.
    * It returns one embedding per text: the hidden state at token position 0
      of the first model output, L2-normalised. For BERT-style encoders such
      as BGE this is presumably the [CLS] token.
    * ``parameters`` is accepted for interface compatibility but unused.
    * An empty ``texts`` list yields ``[]`` without loading the model.
    * In ``TEST_MODE`` it returns a fixed 2-d vector per text.
    """

    def process_texts(texts: list[str], parameters: dict) -> list[list[float]]:
        if not texts:
            # Nothing to embed; skip model loading and the tokenizer call.
            return []

        if TEST_MODE:
            # Fresh inner list per row so rows don't alias one shared object.
            return [[0.1, 0.2] for _ in texts]

        if model_name in loaded_models:
            tokenizer, model = loaded_models[model_name]
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            # Cache under model_name, the same key checked above. Keying by the
            # model object meant the cache never hit: weights were reloaded on
            # every request and each copy stayed referenced in loaded_models.
            loaded_models[model_name] = (tokenizer, model)

        # Tokenize the batch, padding to the longest text and truncating long ones.
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Position-0 token of the first output, one row per input text.
            embeddings = model_output[0][:, 0]

        # L2-normalise each row so dot products equal cosine similarity.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.tolist()

    return process_texts

# Cache of (tokenizer, model) pairs, filled lazily by the functions built in
# generic_sentence_embeddings, which look entries up by model name.
loaded_models: dict = {}

# Model name -> datetime of the most recent successful /sentence-embeddings
# request for that model.
# NOTE(review): the original comment here was truncated ("Polling every X
# minutes to ..."). Presumably a poller elsewhere uses these timestamps to
# evict idle models from loaded_models — TODO confirm.
loaded_models_last_updated: dict = {}

# Models served by the endpoint, each backed by the generic BGE-style handler.
sentence_embeddings_mapping = {
    name: generic_sentence_embeddings(name)
    for name in (
        'BAAI/bge-base-en-v1.5',
        'BAAI/bge-large-en-v1.5',
    )
}