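"""Paragraph-similarity microservice.

Embeds a query and a list of paragraphs with the intfloat/e5-small-v2 model,
mean-pools and L2-normalizes the token embeddings, builds an in-memory hnswlib
index over the paragraph embeddings, and exposes a FastAPI POST endpoint that
returns the indices of the top_k paragraphs most similar to the query.
"""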
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import hnswlib
import numpy as np
import datetime
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

if torch.cuda.is_available():
    print("CUDA is available! Inference on GPU!")
else:
    print("CUDA is not available. Inference on CPU.")

separator = "-HFSEP-"
base_name = "intfloat/e5-small-v2"
# Fall back to CPU when no GPU is present instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
max_length = 512
max_batch_size = 500
tokenizer = AutoTokenizer.from_pretrained(base_name)
model = AutoModel.from_pretrained(base_name).to(device)
model.eval()  # disable dropout for deterministic inference

def current_timestamp():
    # Seconds since the Unix epoch in UTC; datetime.utcnow() is deprecated.
    return datetime.datetime.now(datetime.timezone.utc).timestamp()

def get_embeddings(input_texts):
    # Cap the batch size to avoid running out of memory on very large requests.
    input_texts = input_texts[:max_batch_size]
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        outputs = model(**batch_dict)
        # Mean-pool the token embeddings (ignoring padding), then L2-normalize
        # so that inner product equals cosine similarity.
        embeddings = _average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask']
        )
        embeddings = F.normalize(embeddings, p=2, dim=1)
        embeddings_np = embeddings.cpu().numpy()

        if device == "cuda":
            # Release the GPU memory held by this batch before returning.
            del embeddings
            torch.cuda.empty_cache()

        return embeddings_np

def _average_pool(
    last_hidden_states,
    attention_mask
):
    # Average the token embeddings, masking out padding positions.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
    # Inner-product space on L2-normalized vectors is equivalent to cosine similarity.
    index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
    index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
    ids = np.arange(embeddings_np.shape[0])
    index.add_items(embeddings_np, ids)
    return index

def preprocess_texts(query, paragraphs):
    # E5 models expect "query: " / "passage: " prefixes on their inputs.
    query = f'query: {query}'
    paragraphs = [f'passage: {p}' for p in paragraphs]
    return [query] + paragraphs


app = FastAPI()

class EmbeddingsSimilarityReq(BaseModel):
    paragraphs: List[str]
    query: str
    top_k: int

@app.post("/")
async def find_similar_paragraphsitem(req: EmbeddingsSimilarityReq):
    print("Len of batches", len(req.paragraphs))

    print("creating embeddings", current_timestamp())
    inputs = preprocess_texts(req.query, req.paragraphs)
    embeddings_np = get_embeddings(inputs)
    query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
    
    print("creating index", current_timestamp())
    search_index = create_hnsw_index(chunks_embeddings)
    print("searching index", current_timestamp())
    labels, _ = search_index.knn_query(query_embedding, k=min(int(req.top_k), len(chunks_embeddings)))
    labels = labels[0].tolist()
    return labels
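

# Example client call -- a minimal sketch, not part of the service itself.
# It assumes the app is served locally on port 8000 (e.g. `uvicorn main:app`);
# the module name, host, port, and sample texts below are assumptions, so
# adjust them to your deployment.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/",
#       json={
#           "query": "What is HNSW?",
#           "paragraphs": [
#               "HNSW is an approximate nearest neighbor search index.",
#               "FastAPI is a web framework for building APIs with Python.",
#           ],
#           "top_k": 1,
#       },
#   )
#   print(resp.json())  # e.g. [0] -- indices into the submitted `paragraphs` list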